KelSolaar commited on
Commit
fa06c67
·
0 Parent(s):

Initial commit.

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +37 -0
  2. .gitignore +37 -0
  3. .pre-commit-config.yaml +39 -0
  4. LICENSE +11 -0
  5. README.md +278 -0
  6. docs/_static/gamma_sweep_plot.pdf +0 -0
  7. docs/_static/gamma_sweep_plot.png +3 -0
  8. docs/learning_munsell.md +478 -0
  9. learning_munsell/__init__.py +7 -0
  10. learning_munsell/analysis/__init__.py +1 -0
  11. learning_munsell/analysis/error_analysis.py +304 -0
  12. learning_munsell/comparison/from_xyY/__init__.py +1 -0
  13. learning_munsell/comparison/from_xyY/compare_all_models.py +1292 -0
  14. learning_munsell/comparison/from_xyY/compare_gamma_model.py +390 -0
  15. learning_munsell/comparison/to_xyY/__init__.py +1 -0
  16. learning_munsell/comparison/to_xyY/compare_all_models.py +617 -0
  17. learning_munsell/data_generation/generate_training_data.py +310 -0
  18. learning_munsell/interpolation/__init__.py +1 -0
  19. learning_munsell/interpolation/from_xyY/__init__.py +43 -0
  20. learning_munsell/interpolation/from_xyY/compare_methods.py +208 -0
  21. learning_munsell/interpolation/from_xyY/delaunay_interpolator.py +283 -0
  22. learning_munsell/interpolation/from_xyY/kdtree_interpolator.py +263 -0
  23. learning_munsell/interpolation/from_xyY/rbf_interpolator.py +300 -0
  24. learning_munsell/losses/__init__.py +17 -0
  25. learning_munsell/losses/jax_delta_e.py +299 -0
  26. learning_munsell/models/__init__.py +47 -0
  27. learning_munsell/models/networks.py +1294 -0
  28. learning_munsell/training/from_xyY/__init__.py +1 -0
  29. learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py +503 -0
  30. learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py +541 -0
  31. learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py +552 -0
  32. learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py +471 -0
  33. learning_munsell/training/from_xyY/refine_multi_head_real.py +358 -0
  34. learning_munsell/training/from_xyY/train_deep_wide.py +371 -0
  35. learning_munsell/training/from_xyY/train_ft_transformer.py +356 -0
  36. learning_munsell/training/from_xyY/train_mixture_of_experts.py +620 -0
  37. learning_munsell/training/from_xyY/train_mlp.py +269 -0
  38. learning_munsell/training/from_xyY/train_mlp_attention.py +460 -0
  39. learning_munsell/training/from_xyY/train_mlp_error_predictor.py +457 -0
  40. learning_munsell/training/from_xyY/train_mlp_gamma.py +297 -0
  41. learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py +411 -0
  42. learning_munsell/training/from_xyY/train_multi_head_circular.py +479 -0
  43. learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py +640 -0
  44. learning_munsell/training/from_xyY/train_multi_head_gamma.py +300 -0
  45. learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py +605 -0
  46. learning_munsell/training/from_xyY/train_multi_head_large.py +246 -0
  47. learning_munsell/training/from_xyY/train_multi_head_mlp.py +269 -0
  48. learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py +378 -0
  49. learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py +409 -0
  50. learning_munsell/training/from_xyY/train_multi_head_st2084.py +313 -0
.gitattributes ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.7z filter=lfs diff=lfs merge=lfs -text
2
+ *.arrow filter=lfs diff=lfs merge=lfs -text
3
+ *.bin filter=lfs diff=lfs merge=lfs -text
4
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
5
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
6
+ *.ftz filter=lfs diff=lfs merge=lfs -text
7
+ *.gz filter=lfs diff=lfs merge=lfs -text
8
+ *.h5 filter=lfs diff=lfs merge=lfs -text
9
+ *.joblib filter=lfs diff=lfs merge=lfs -text
10
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
11
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
12
+ *.model filter=lfs diff=lfs merge=lfs -text
13
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
14
+ *.npy filter=lfs diff=lfs merge=lfs -text
15
+ *.npz filter=lfs diff=lfs merge=lfs -text
16
+ *.onnx filter=lfs diff=lfs merge=lfs -text
17
+ *.onnx.data filter=lfs diff=lfs merge=lfs -text
18
+ *.ot filter=lfs diff=lfs merge=lfs -text
19
+ *.parquet filter=lfs diff=lfs merge=lfs -text
20
+ *.pb filter=lfs diff=lfs merge=lfs -text
21
+ *.pickle filter=lfs diff=lfs merge=lfs -text
22
+ *.pkl filter=lfs diff=lfs merge=lfs -text
23
+ *.pt filter=lfs diff=lfs merge=lfs -text
24
+ *.pth filter=lfs diff=lfs merge=lfs -text
25
+ *.rar filter=lfs diff=lfs merge=lfs -text
26
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
27
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
28
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
29
+ *.tar filter=lfs diff=lfs merge=lfs -text
30
+ *.tflite filter=lfs diff=lfs merge=lfs -text
31
+ *.tgz filter=lfs diff=lfs merge=lfs -text
32
+ *.wasm filter=lfs diff=lfs merge=lfs -text
33
+ *.xz filter=lfs diff=lfs merge=lfs -text
34
+ *.zip filter=lfs diff=lfs merge=lfs -text
35
+ *.zst filter=lfs diff=lfs merge=lfs -text
36
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
37
+ *.png filter=lfs diff=lfs merge=lfs -text
.gitignore ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Common Files
2
+ *.egg-info
3
+ *.pyc
4
+ *.pyo
5
+ .DS_Store
6
+ .coverage*
7
+ uv.lock
8
+
9
+ # Common Directories
10
+ .fleet/
11
+ .idea/
12
+ .ipynb_checkpoints/
13
+ .python-version
14
+ .vs/
15
+ .vscode/
16
+ .sandbox/
17
+ build/
18
+ dist/
19
+ docs/_build/
20
+ docs/generated/
21
+ node_modules/
22
+ references/
23
+
24
+ __pycache__
25
+
26
+ .claude/settings.local.json
27
+ .claude/scratchpad.md
28
+
29
+ # Project Directories
30
+ data/
31
+ logs/
32
+ mlartifacts/
33
+ mlruns/
34
+ mlruns.db
35
+ reports/
36
+ results/
37
+ runs/
.pre-commit-config.yaml ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ repos:
2
+ - repo: https://github.com/pre-commit/pre-commit-hooks
3
+ rev: "v5.0.0"
4
+ hooks:
5
+ - id: check-added-large-files
6
+ - id: check-case-conflict
7
+ - id: check-merge-conflict
8
+ - id: check-symlinks
9
+ - id: check-yaml
10
+ - id: debug-statements
11
+ - id: end-of-file-fixer
12
+ - id: mixed-line-ending
13
+ - id: requirements-txt-fixer
14
+ - id: trailing-whitespace
15
+ - repo: https://github.com/codespell-project/codespell
16
+ rev: v2.4.1
17
+ hooks:
18
+ - id: codespell
19
+ args: ["--ignore-words-list=colour"]
20
+ - repo: https://github.com/PyCQA/isort
21
+ rev: "6.0.1"
22
+ hooks:
23
+ - id: isort
24
+ - repo: https://github.com/astral-sh/ruff-pre-commit
25
+ rev: "v0.12.4"
26
+ hooks:
27
+ - id: ruff-format
28
+ - id: ruff
29
+ args: [--fix]
30
+ - repo: https://github.com/pre-commit/mirrors-prettier
31
+ rev: "v4.0.0-alpha.8"
32
+ hooks:
33
+ - id: prettier
34
+ - repo: https://github.com/pre-commit/pygrep-hooks
35
+ rev: "v1.10.0"
36
+ hooks:
37
+ - id: rst-backticks
38
+ - id: rst-directive-colons
39
+ - id: rst-inline-touching-normal
LICENSE ADDED
@@ -0,0 +1,11 @@
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Copyright 2025 Colour Developers
2
+
3
+ Redistribution and use in source and binary forms, with or without modification, are permitted provided that the following conditions are met:
4
+
5
+ 1. Redistributions of source code must retain the above copyright notice, this list of conditions and the following disclaimer.
6
+
7
+ 2. Redistributions in binary form must reproduce the above copyright notice, this list of conditions and the following disclaimer in the documentation and/or other materials provided with the distribution.
8
+
9
+ 3. Neither the name of the copyright holder nor the names of its contributors may be used to endorse or promote products derived from this software without specific prior written permission.
10
+
11
+ THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
README.md ADDED
@@ -0,0 +1,278 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: bsd-3-clause
3
+ language:
4
+ - en
5
+ tags:
6
+ - python
7
+ - colour
8
+ - color
9
+ - colour-science
10
+ - color-science
11
+ - colour-spaces
12
+ - color-spaces
13
+ - colourspace
14
+ - colorspace
15
+ pipeline_tag: tabular-regression
16
+ library_name: onnxruntime
17
+ metrics:
18
+ - mae
19
+ model-index:
20
+ - name: from_xyY (CIE xyY to Munsell)
21
+ results:
22
+ - task:
23
+ type: tabular-regression
24
+ name: CIE xyY to Munsell Specification
25
+ dataset:
26
+ name: CIE xyY to Munsell Specification
27
+ type: munsell-renotation
28
+ metrics:
29
+ - type: delta-e
30
+ value: 0.52
31
+ name: Delta-E CIE2000
32
+ - type: inference_time_ms
33
+ value: 0.089
34
+ name: Inference Time (ms/sample)
35
+ - name: to_xyY (Munsell to CIE xyY)
36
+ results:
37
+ - task:
38
+ type: tabular-regression
39
+ name: Munsell Specification to CIE xyY
40
+ dataset:
41
+ name: Munsell Specification to CIE xyY
42
+ type: munsell-renotation
43
+ metrics:
44
+ - type: delta-e
45
+ value: 0.48
46
+ name: Delta-E CIE2000
47
+ - type: inference_time_ms
48
+ value: 0.008
49
+ name: Inference Time (ms/sample)
50
+ ---
51
+
52
+ # Learning Munsell - Machine Learning for Munsell Color Conversions
53
+
54
+ A project implementing machine learning-based methods for bidirectional conversion between CIE xyY colourspace values and Munsell specifications.
55
+
56
+ **Two Conversion Directions:**
57
+
58
+ - **from_xyY**: CIE xyY to Munsell specification
59
+ - **to_xyY**: Munsell specification to CIE xyY
60
+
61
+ ## Project Overview
62
+
63
+ ### Objective
64
+
65
+ Provide 100-1000x speedup for batch Munsell conversions compared to colour-science routines while maintaining high accuracy.
66
+
67
+ ### Results
68
+
69
+ **from_xyY** (CIE xyY to Munsell) — evaluated on all 2,734 REAL Munsell colors:
70
+
71
+ | Model | Delta-E | Speed (ms) |
72
+ |----------------------------------------------------------| ---------- | ---------- |
73
+ | Colour Library (Baseline) | 0.00 | 111.90 |
74
+ | **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 |
75
+ | Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 |
76
+ | Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 |
77
+ | Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 |
78
+ | MLP + Error Predictor | 0.53 | 0.030 |
79
+ | Multi-ResNet (Large Dataset) | 0.54 | 0.044 |
80
+ | Multi-Head + Multi-Error Predictor | 0.54 | 0.042 |
81
+ | Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 |
82
+ | Deep + Wide | 0.60 | 0.074 |
83
+ | Multi-Head (Large Dataset) | 0.66 | 0.013 |
84
+ | Mixture of Experts | 0.80 | 0.020 |
85
+ | Transformer (Large Dataset) | 0.82 | 0.123 |
86
+ | Multi-MLP | 0.86 | 0.027 |
87
+ | MLP + Self-Attention | 0.88 | 0.173 |
88
+ | MLP (Base Only) | 1.09 | **0.007** |
89
+ | Unified MLP | 1.12 | 0.072 |
90
+
91
+ - **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) — Delta-E 0.52, 1,252x faster
92
+ - **Fastest**: MLP Base Only (0.007 ms/sample) — 15,492x faster than Colour library
93
+ - **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large — 1,951x faster with Delta-E 0.52
94
+
95
+ **to_xyY** (Munsell to CIE xyY) — evaluated on all 2,734 REAL Munsell colors:
96
+
97
+ | Model | Delta-E | Speed (ms) |
98
+ | --------------------------------------------- | ---------- | ----------- |
99
+ | Colour Library (Baseline) | 0.00 | 1.27 |
100
+ | **Multi-MLP (Optimized)** | **0.48** | 0.008 |
101
+ | Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 |
102
+ | Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 |
103
+ | Multi-MLP | 0.66 | 0.016 |
104
+ | Multi-MLP + Error Predictor | 0.67 | 0.018 |
105
+ | Multi-Head (Optimized) | 0.71 | 0.015 |
106
+ | Multi-Head | 0.78 | 0.008 |
107
+ | Multi-Head + Multi-Error Predictor | 1.11 | 0.028 |
108
+ | Simple MLP | 1.42 | **0.0008** |
109
+
110
+ - **Best Accuracy**: Multi-MLP (Optimized) — Delta-E 0.48, 154x faster
111
+ - **Fastest**: Simple MLP (0.0008 ms/sample) — 1,654x faster than Colour library
112
+
113
+ ### Approach
114
+
115
+ - **25+ architectures** tested for from_xyY (MLP, Multi-Head, Multi-MLP, Multi-ResNet, Transformers, Mixture of Experts)
116
+ - **9 architectures** tested for to_xyY (Simple MLP, Multi-Head, Multi-MLP with error predictors)
117
+ - **Two-stage models** (base + error predictor) on large dataset proved most effective
118
+ - **Best model**: Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52
119
+ - **Training data**: ~1.4M samples from dense xyY grid with boundary refinement and forward Munsell sampling
120
+ - **Deployment**: ONNX format with ONNX Runtime
121
+
122
+ For detailed architecture comparisons, model benchmarks, training pipeline details, and experimental results, see [docs/learning_munsell.md](docs/learning_munsell.md).
123
+
124
+ ## Installation
125
+
126
+ **Dependencies (Runtime)**:
127
+
128
+ - numpy >= 2.0
129
+ - onnxruntime >= 1.16
130
+
131
+ **Dependencies (Training)**:
132
+
133
+ - torch >= 2.0
134
+ - scikit-learn >= 1.3
135
+ - matplotlib >= 3.9
136
+ - mlflow >= 2.10
137
+ - optuna >= 3.0
138
+ - colour-science >= 0.4.7
139
+ - click >= 8.0
140
+ - onnx >= 1.15
141
+ - onnxscript >= 0.5.6
142
+ - tqdm >= 4.66
143
+ - jax >= 0.4.20
144
+ - jaxlib >= 0.4.20
145
+ - flax >= 0.10.7
146
+ - optax >= 0.2.6
147
+ - scipy >= 1.12
148
+ - tensorboard >= 2.20
149
+
150
+ From the project root:
151
+
152
+ ```bash
153
+ cd learning-munsell
154
+
155
+ # Install all dependencies (creates virtual environment automatically)
156
+ uv sync
157
+ ```
158
+
159
+ ## Usage
160
+
161
+ ### Generate Training Data
162
+
163
+ ```bash
164
+ uv run python learning_munsell/data_generation/generate_training_data.py
165
+ ```
166
+
167
+ **Note**: This step is computationally expensive (uses iterative algorithm for ground truth).
168
+
169
+ ### Train Models
170
+
171
+ **xyY to Munsell (from_xyY)**
172
+
173
+ Best performing model (Multi-ResNet + Multi-Error Predictor on Large Dataset):
174
+
175
+ ```bash
176
+ # Train base Multi-ResNet on large dataset (~1.4M samples)
177
+ uv run python learning_munsell/training/from_xyY/train_multi_resnet_large.py
178
+
179
+ # Train multi-error predictor
180
+ uv run python learning_munsell/training/from_xyY/train_multi_resnet_error_predictor_large.py
181
+ ```
182
+
183
+ Alternative (Multi-Head architecture):
184
+
185
+ ```bash
186
+ uv run python learning_munsell/training/from_xyY/train_multi_head_large.py
187
+ uv run python learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py
188
+ ```
189
+
190
+ Other architectures:
191
+
192
+ ```bash
193
+ uv run python learning_munsell/training/from_xyY/train_unified_mlp.py
194
+ uv run python learning_munsell/training/from_xyY/train_multi_mlp.py
195
+ uv run python learning_munsell/training/from_xyY/train_mlp_attention.py
196
+ uv run python learning_munsell/training/from_xyY/train_deep_wide.py
197
+ uv run python learning_munsell/training/from_xyY/train_ft_transformer.py
198
+ ```
199
+
200
+ **Munsell to xyY (to_xyY)**
201
+
202
+ Best performing model (Multi-MLP Optimized):
203
+
204
+ ```bash
205
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp.py
206
+ uv run python learning_munsell/training/to_xyY/train_multi_head.py
207
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp_multi_error_predictor.py
208
+ uv run python learning_munsell/training/to_xyY/train_multi_mlp_error_predictor.py
209
+ uv run python learning_munsell/training/to_xyY/train_multi_head_multi_error_predictor.py
210
+ ```
211
+
212
+ Train the differentiable approximator for use in Delta-E loss:
213
+
214
+ ```bash
215
+ uv run python learning_munsell/training/to_xyY/train_munsell_to_xyY_approximator.py
216
+ ```
217
+
218
+ ### Hyperparameter Search
219
+
220
+ ```bash
221
+ uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py
222
+ uv run python learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py
223
+ ```
224
+
225
+ ### Compare All Models
226
+
227
+ ```bash
228
+ uv run python learning_munsell/comparison/from_xyY/compare_all_models.py
229
+ ```
230
+
231
+ Generates comprehensive HTML report at `reports/from_xyY/model_comparison.html`.
232
+
233
+ ### Monitor Training
234
+
235
+ **MLflow**:
236
+
237
+ ```bash
238
+ uv run mlflow ui --backend-store-uri "sqlite:///mlruns.db" --port=5000
239
+ ```
240
+
241
+ Open <http://localhost:5000> in your browser.
242
+
243
+ ## Directory Structure
244
+
245
+ ```
246
+ learning-munsell/
247
+ +-- data/ # Training data
248
+ | +-- training_data.npz # Generated training samples
249
+ | +-- training_data_large.npz # Large dataset (~1.4M samples)
250
+ | +-- training_data_params.json # Generation parameters
251
+ | +-- training_data_large_params.json
252
+ +-- models/ # Trained models (ONNX + PyTorch)
253
+ | +-- from_xyY/ # xyY to Munsell models (25+ ONNX models)
254
+ | | +-- multi_resnet_error_predictor_large.onnx # BEST
255
+ | | +-- ... (additional model variants)
256
+ | +-- to_xyY/ # Munsell to xyY models (9 ONNX models)
257
+ | +-- multi_mlp_optimized.onnx # BEST
258
+ | +-- ... (additional model variants)
259
+ +-- learning_munsell/ # Source code
260
+ | +-- analysis/ # Analysis scripts
261
+ | +-- comparison/ # Model comparison scripts
262
+ | +-- data_generation/ # Data generation scripts
263
+ | +-- interpolation/ # Classical interpolation methods
264
+ | +-- losses/ # Loss functions (JAX Delta-E)
265
+ | +-- models/ # Model architecture definitions
266
+ | +-- training/ # Model training scripts
267
+ | +-- utilities/ # Shared utilities
268
+ +-- docs/ # Documentation
269
+ +-- reports/ # HTML comparison reports
270
+ +-- logs/ # Script output logs
271
+ +-- mlruns.db # MLflow experiment tracking database
272
+ ```
273
+
274
+ ## About
275
+
276
+ **Learning Munsell** by Colour Developers
277
+ Research project for the Colour library
278
+ <https://github.com/colour-science/colour>
docs/_static/gamma_sweep_plot.pdf ADDED
Binary file (22.2 kB). View file
 
docs/_static/gamma_sweep_plot.png ADDED

Git LFS Details

  • SHA256: e2a0d5dc57c0d37d5889cff4ac41a08b490387a54615d4372af5e5bd86018e36
  • Pointer size: 131 Bytes
  • Size of remote file: 137 kB
docs/learning_munsell.md ADDED
@@ -0,0 +1,478 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Learning Munsell
2
+
3
+ Technical documentation covering performance benchmarks, training methodology, architecture design, and experimental findings.
4
+
5
+ ## Overview
6
+
7
+ This project implements ML models for bidirectional conversion between CIE xyY colorspace values and Munsell specifications:
8
+
9
+ - **xyY to Munsell (from_xyY)**: 25+ architectures, best Delta-E 0.52
10
+ - **Munsell to xyY (to_xyY)**: 9 architectures, best Delta-E 0.48
11
+
12
+ ### Delta-E Interpretation
13
+
14
+ - **< 1.0**: Not perceptible by the human eye
15
+ - **1-2**: Perceptible through close observation
16
+ - **2-10**: Perceptible at a glance
17
+ - **> 10**: Colors are perceived as completely different
18
+
19
+ Our best models achieve **Delta-E 0.48-0.52**, meaning the difference between ML prediction and iterative algorithm is **not perceptible by the human eye**.
20
+
21
+ ## xyY to Munsell (from_xyY)
22
+
23
+ ### Performance Benchmarks
24
+
25
+ Comprehensive comparison using all 2,734 REAL Munsell colors:
26
+
27
+ | Model | Delta-E | Speed (ms) |
28
+ |----------------------------------------------------------|-------------|------------|
29
+ | Colour Library (Baseline) | 0.00 | 111.90 |
30
+ | **Multi-ResNet + Multi-Error Predictor (Large Dataset)** | **0.52** | 0.089 |
31
+ | Multi-MLP (W+B) + Multi-Error Predictor (W+B) Large | 0.52 | 0.057 |
32
+ | Multi-MLP + Multi-Error Predictor (Large Dataset) | 0.52 | 0.058 |
33
+ | Multi-MLP + Multi-Error Predictor | 0.53 | 0.058 |
34
+ | MLP + Error Predictor | 0.53 | 0.030 |
35
+ | Multi-ResNet (Large Dataset) | 0.54 | 0.044 |
36
+ | Multi-Head + Multi-Error Predictor | 0.54 | 0.042 |
37
+ | Multi-Head + Multi-Error Predictor (Large Dataset) | 0.56 | 0.043 |
38
+ | Deep + Wide | 0.60 | 0.074 |
39
+ | Multi-Head (Large Dataset) | 0.66 | 0.013 |
40
+ | Mixture of Experts | 0.80 | 0.020 |
41
+ | Transformer (Large Dataset) | 0.82 | 0.123 |
42
+ | Multi-MLP | 0.86 | 0.027 |
43
+ | MLP + Self-Attention | 0.88 | 0.173 |
44
+ | MLP (Base Only) | 1.09 | **0.007** |
45
+ | Unified MLP | 1.12 | 0.072 |
46
+
47
+ Note: The Colour library baseline had 171 convergence failures out of 2,734 samples (6.3% failure rate).
48
+
49
+ **Best Models**:
50
+
51
+ - **Best Accuracy**: Multi-ResNet + Multi-Error Predictor (Large Dataset) - Delta-E 0.52
52
+ - **Fastest**: MLP Base Only (0.007 ms/sample) - 15,492x faster than Colour library
53
+ - **Best Balance**: Multi-MLP (W+B: Weighted Boundary) + Multi-Error Predictor (W+B) Large - 1,951x faster with Delta-E 0.52
54
+
55
+ ### Model Architectures
56
+
57
+ 25+ architectures were systematically evaluated:
58
+
59
+ **Single-Stage Models**
60
+
61
+ 1. **MLP (Base Only)** - Simple MLP network, 3 inputs to 4 outputs
62
+ 2. **Unified MLP** - Single large MLP with shared features
63
+ 3. **Multi-Head** - Shared encoder with 4 independent decoder heads
64
+ 4. **Multi-Head (Large Dataset)** - Multi-Head trained on 1.4M samples
65
+ 5. **Multi-MLP** - 4 completely independent MLP branches (one per output)
66
+ 6. **Multi-MLP (Large Dataset)** - Multi-MLP trained on 1.4M samples
67
+ 7. **MLP + Self-Attention** - MLP with attention mechanism for feature weighting
68
+ 8. **Deep + Wide** - Combined deep and wide network paths
69
+ 9. **Mixture of Experts** - Gating network selecting specialized expert networks
70
+ 10. **Transformer (Large Dataset)** - Feature Tokenizer Transformer for tabular data
71
+ 11. **FT-Transformer** - Feature Tokenizer Transformer (standard size)
72
+
73
+ **Two-Stage Models**
74
+
75
+ 12. **MLP + Error Predictor** - Base MLP with unified error correction
76
+ 13. **Multi-Head + Multi-Error Predictor** - Multi-Head with 4 independent error predictors
77
+ 14. **Multi-Head + Multi-Error Predictor (Large Dataset)** - Large dataset variant
78
+ 15. **Multi-MLP + Multi-Error Predictor** - 4 independent branches with 4 independent error predictors
79
+ 16. **Multi-MLP + Multi-Error Predictor (Large Dataset)** - Large dataset variant
80
+ 17. **Multi-ResNet + Multi-Error Predictor (Large Dataset)** - Deep ResNet-style branches (BEST)
81
+
82
+ The **Multi-ResNet + Multi-Error Predictor (Large Dataset)** architecture achieved the best results with Delta-E 0.52.
83
+
84
+ ### Training Methodology
85
+
86
+ **Data Generation**
87
+
88
+ 1. **Dense xyY Grid** (~500K samples)
89
+ - Regular grid in valid xyY space (MacAdam limits for Illuminant C)
90
+ - Captures general input distribution
91
+ 2. **Boundary Refinement** (~700K samples)
92
+ - Adaptive dense sampling near Munsell gamut boundaries
93
+ - Uses `maximum_chroma_from_renotation` to detect edges
94
+ - Focuses on regions where iterative algorithm is most complex
95
+ - Includes Y/GY/G hue regions with high value/chroma (challenging areas)
96
+ 3. **Forward Augmentation** (~200K samples)
97
+ - Dense Munsell space sampling via `munsell_specification_to_xyY`
98
+ - Ensures coverage of known valid colors
99
+
100
+ Total: ~1.4M samples for large dataset training.
101
+
102
+ **Loss Functions**
103
+
104
+ Two loss function approaches were tested:
105
+
106
+ *Precision-Focused Loss* (Default):
107
+
108
+ ```
109
+ total_loss = 1.0 * MSE + 0.5 * MAE + 0.3 * log_penalty + 0.5 * huber_loss
110
+ ```
111
+
112
+ - MSE: Standard mean squared error
113
+ - MAE: Mean absolute error
114
+ - Log penalty: Heavily penalizes small errors (pushes toward high precision)
115
+ - Huber loss: Small delta (0.01) for precision on small errors
116
+
117
+ *Pure MSE Loss* (Optimized config):
118
+
119
+ ```
120
+ total_loss = MSE
121
+ ```
122
+
123
+ Interestingly, the precision-focused loss achieved better Delta-E despite higher validation MSE, suggesting the custom weighting better correlates with perceptual accuracy.
124
+
125
+ ### Design Rationale
126
+
127
+ **Two-Stage Architecture**
128
+
129
+ The error predictor stage corrects systematic biases in the base model:
130
+
131
+ 1. Base model learns the general xyY to Munsell mapping
132
+ 2. Error predictor learns residual corrections specific to each component
133
+ 3. Combined prediction: `final = base_prediction + error_correction`
134
+
135
+ This decomposition allows each stage to specialize and reduces the complexity each network must learn.
136
+
137
+ **Independent Branch Design**
138
+
139
+ Munsell components have different characteristics:
140
+
141
+ - **Hue**: Circular (0-10, wrapping), most complex
142
+ - **Value**: Linear (0-10), easiest to predict
143
+ - **Chroma**: Highly variable range depending on hue/value
144
+ - **Code**: Discrete hue sector (0-9)
145
+
146
+ Shared encoders force compromises between these different prediction tasks. Independent branches allow full specialization.
147
+
148
+ **Architecture Details**
149
+
150
+ *MLP (Base Only)*
151
+
152
+ Simple feedforward network predicting all 4 outputs simultaneously:
153
+
154
+ Input (3) ──► Linear Layers ──► Output (4: hue, value, chroma, code)
155
+
156
+ - Smallest model (~8KB ONNX)
157
+ - Fastest inference (0.007 ms)
158
+ - Baseline for comparison
159
+
160
+ *Unified MLP*
161
+
162
+ Single large MLP with shared internal features:
163
+
164
+ Input (3) ──► 128 ──► 256 ──► 512 ──► 256 ──► 128 ──► Output (4)
165
+
166
+ - Shared representations across all outputs
167
+ - Moderate size, good speed
168
+
169
+ *Multi-Head MLP*
170
+
171
+ Shared encoder with specialized decoder heads:
172
+
173
+ Input (3) ──► SHARED ENCODER (3→128→256→512) ──┬──► Hue Head (512→256→128→1)
174
+ ├──► Value Head (512→256→128→1)
175
+ ├──► Chroma Head (512→384→256→128→1)
176
+ └──► Code Head (512→256→128→1)
177
+
178
+ - Shared encoder learns common color space features
179
+ - 4 specialized decoder heads branch from shared representation
180
+ - Parameter efficient (encoder weights shared)
181
+ - Fast inference (encoder computed once)
182
+
183
+ *Multi-MLP*
184
+
185
+ Fully independent branches with no weight sharing:
186
+
187
+ Input (3) ──► Hue Branch (3→128→256→512→256→128→1)
188
+ Input (3) ──► Value Branch (3→128→256→512→256→128→1)
189
+ Input (3) ──► Chroma Branch (3→256→512→1024→512→256→1) [2x wider]
190
+ Input (3) ──► Code Branch (3→128→256→512→256→128→1)
191
+
192
+ - 4 completely independent MLPs
193
+ - Each branch learns its own features from scratch
194
+ - Chroma branch is wider (2x) to handle its complexity
195
+ - Better accuracy than Multi-Head on large dataset (Delta-E 0.52 vs 0.56 with error predictors)
196
+
197
+ *Multi-ResNet*
198
+
199
+ Deep branches with residual-style connections:
200
+
201
+ Input (3) ──► Hue Branch (3→256→512→512→512→256→1) [6 layers]
202
+ Input (3) ──► Value Branch (3→256→512→512→512→256→1) [6 layers]
203
+ Input (3) ──► Chroma Branch (3→512→1024→1024→1024→512→1) [6 layers, 2x wider]
204
+ Input (3) ──► Code Branch (3→256→512→512→512→256→1) [6 layers]
205
+
206
+ - Deeper architecture than Multi-MLP
207
+ - BatchNorm + SiLU activation
208
+ - Best accuracy when combined with error predictor (Delta-E 0.52)
209
+ - Largest model (~14MB base, ~28MB with error predictor)
210
+
211
+ *Deep + Wide*
212
+
213
+ Combined deep and wide network paths:
214
+
215
+ Input (3) ──┬──► Deep Path (multiple layers) ──┬──► Concat ──► Output (4)
216
+ └──► Wide Path (direct connection) ─┘
217
+
218
+ - Deep path captures complex patterns
219
+ - Wide path preserves direct input information
220
+ - Good for mixed linear/nonlinear relationships
221
+
222
+ *MLP + Self-Attention*
223
+
224
+ MLP with attention mechanism for feature weighting:
225
+
226
+ Input (3) ──► MLP ──► Self-Attention ──► Output (4)
227
+
228
+ - Attention weights learn feature importance
229
+ - Slower due to attention computation (0.173 ms)
230
+ - Did not improve over simpler MLPs
231
+
232
+ *Mixture of Experts*
233
+
234
+ Gating network selecting specialized expert networks:
235
+
236
+ Input (3) ──► Gating Network ──► Weighted sum of Expert outputs ──► Output (4)
237
+
238
+ - Multiple expert networks specialize in different input regions
239
+ - Gating network learns which expert to use
240
+ - More complex but did not outperform Multi-MLP
241
+
242
+ *FT-Transformer*
243
+
244
+ Feature Tokenizer Transformer for tabular data:
245
+
246
+ Input (3) ──► Feature Tokenizer ──► Transformer Blocks ──► Output (4)
247
+
248
+ - Each input feature tokenized separately
249
+ - Self-attention across feature tokens
250
+ - Good for tabular data with feature interactions
251
+ - Slower inference due to attention computation
252
+
253
+ *Error Predictor (Two-Stage)*
254
+
255
+ Second-stage network that corrects base model errors:
256
+
257
+ Stage 1: Input (3) ──► Base Model ──► Base Prediction (4)
258
+ Stage 2: [Input (3), Base Prediction (4)] ──► Error Predictor ──► Error Correction (4)
259
+ Final: Base Prediction + Error Correction = Final Output
260
+
261
+ - Learns residual corrections for each component
262
+ - Can have unified (1 network) or multi (4 networks) error predictors
263
+ - Consistently improves accuracy across all base architectures
264
+ - Best results: Multi-ResNet + Multi-Error Predictor (Delta-E 0.52)
265
+
266
+ **Loss-Metric Mismatch**
267
+
268
+ An important finding: **optimizing MSE does not optimize Delta-E**.
269
+
270
+ The Optuna hyperparameter search minimized validation MSE, but the best MSE configuration did not achieve the best Delta-E. This is because:
271
+
272
+ - MSE treats all component errors equally
273
+ - Delta-E (CIE2000) weights errors based on human perception
274
+ - The precision-focused loss with custom weights better approximates perceptual importance
275
+
276
+ **Weighted Boundary Loss (Experimental)**
277
+
278
+ Analysis of model errors revealed systematic underperformance on Y/GY/G hues (Yellow/Green-Yellow/Green) with high value and chroma. The weighted boundary loss approach was explored to address this by:
279
+
280
+ 1. Applying 3x loss weight to samples in challenging regions:
281
+ - Hue: 0.18-0.35 (normalized range covering Y/GY/G)
282
+ - Value > 0.7 (high brightness)
283
+ - Chroma > 0.5 (high saturation)
284
+ 2. Adding boundary penalty to prevent predictions exceeding Munsell gamut limits
285
+
286
+ **Finding**: The large dataset approach (~1.4M samples with dense boundary sampling) naturally provides sufficient coverage of these challenging regions. Both the weighted boundary loss model (Multi-MLP W+B + Multi-Error Predictor W+B Large, Delta-E 0.524) and the standard large dataset model (Multi-MLP + Multi-Error Predictor Large, Delta-E 0.525) achieve nearly identical results, making explicit loss weighting optional. The best overall model is Multi-ResNet + Multi-Error Predictor (Large Dataset) with Delta-E 0.52.
287
+
288
+ ### Experimental Findings
289
+
290
+ The following experiments were conducted but did not improve results:
291
+
292
+ **Delta-E Training**
293
+
294
+ Training with differentiable Delta-E CIE2000 loss via round-trip through the Munsell-to-xyY approximator.
295
+
296
+ *Hypothesis*: Perceptual Delta-E loss might outperform MSE-trained models.
297
+
298
+ *Implementation*: JAX/Flax model with combined MSE + Delta-E loss. Requires lower learning rate (1e-4 vs 3e-4) for stability; higher rates cause NaN gradients.
299
+
300
+ *Results*: While Delta-E is comparable, **hue accuracy is ~10x worse**:
301
+
302
+ | Metric (Normalized MAE) | Delta-E Model | MSE Model |
303
+ |--------------------------|---------------|-----------|
304
+ | Hue MAE | 0.30 | 0.03 |
305
+ | Value MAE | 0.002 | 0.004 |
306
+ | Chroma MAE | 0.007 | 0.008 |
307
+ | Code MAE | 0.07 | 0.01 |
308
+ | **Delta-E (perceptual)** | **0.52** | **0.50** |
309
+
310
+ *Key Takeaway*: **Perceptual similarity != specification accuracy**. The MSE model's slightly better Delta-E (0.50 vs 0.52) comes at the cost of ~10x worse hue accuracy, making it unsuitable for specification prediction. Delta-E is too permissive for hue, allowing the model to find "shortcuts" that minimize perceptual difference without correctly predicting the Munsell specification.
311
+
312
+ **Classical Interpolation**
313
+
314
+ Classical interpolation methods were tested on 4,995 reference Munsell colors (80% train / 20% test split). ML models were evaluated on 2,734 REAL Munsell colors.
315
+
316
+ *Results (Validation MAE)*:
317
+
318
+ | Component | RBF | KD-Tree | Delaunay | ML (Best) |
319
+ |-----------|------|---------|----------|-----------|
320
+ | Hue | 1.40 | 1.40 | 1.29 | **0.03** |
321
+ | Value | 0.01 | 0.10 | 0.02 | 0.05 |
322
+ | Chroma | 0.22 | 0.99 | 0.35 | **0.11** |
323
+ | Code | 0.33 | 0.28 | 0.28 | **0.00** |
324
+
325
+ *Key Insight*: The reference dataset (4,995 colors) is too sparse for 3D xyY interpolation. Classical methods fail on hue prediction (MAE ~1.3-1.4), while ML achieves 47x better hue accuracy and 2-3x better chroma/code accuracy.
326
+
327
+ **Circular Hue Loss**
328
+
329
+ Circular distance metrics for hue prediction, accounting for cyclic nature (0-10 wraps).
330
+
331
+ *Results*: The circular loss model performed **21x worse** on hue MAE (5.14 vs 0.24).
332
+
333
+ *Key Takeaway*: **Mathematical correctness != training effectiveness**. The circular distance creates gradient discontinuities that harm optimization.
334
+
335
+ **REAL-Only Refinement**
336
+
337
+ Fine-tuning using only REAL Munsell colors (2,734) instead of ALL colors (4,995).
338
+
339
+ *Results*: Essentially identical performance (Delta-E 1.5233 vs 1.5191).
340
+
341
+ *Key Takeaway*: **Data quality is not the bottleneck**. Both REAL and extrapolated colors are sufficiently accurate.
342
+
343
+ **Gamma Normalization**
344
+
345
+ Gamma correction to the Y (luminance) channel during normalization.
346
+
347
+ *Results*: No consistent improvement across gamma values 1.0-3.0:
348
+
349
+ | Gamma | Median ΔE (± std) |
350
+ |----------------|-------------------|
351
+ | 1.0 (baseline) | 0.730 ± 0.054 |
352
+ | 2.5 (best) | 0.683 ± 0.132 |
353
+
354
+ ![Gamma sweep results](_static/gamma_sweep_plot.png)
355
+
356
+ *Key Takeaway*: **Gamma normalization does not provide consistent improvement**. Standard deviations overlap - differences are within noise.
357
+
358
+ ## Munsell to xyY (to_xyY)
359
+
360
+ ### Performance Benchmarks
361
+
362
+ Comprehensive comparison using all 2,734 REAL Munsell colors:
363
+
364
+ | Model | Delta-E | Speed (ms) |
365
+ |-----------------------------------------------|-------------|------------|
366
+ | Colour Library (Baseline) | 0.00 | 1.27 |
367
+ | **Multi-MLP (Optimized)** | **0.48** | 0.008 |
368
+ | Multi-MLP (Opt) + Multi-Error Predictor (Opt) | 0.48 | 0.025 |
369
+ | Multi-MLP + Multi-Error Predictor | 0.65 | 0.030 |
370
+ | Multi-MLP | 0.66 | 0.016 |
371
+ | Multi-MLP + Error Predictor | 0.67 | 0.018 |
372
+ | Multi-Head (Optimized) | 0.71 | 0.015 |
373
+ | Multi-Head | 0.78 | 0.008 |
374
+ | Multi-Head + Multi-Error Predictor | 1.11 | 0.028 |
375
+ | Simple MLP | 1.42 | **0.0008** |
376
+
377
+ **Best Models**:
378
+
379
+ - **Best Accuracy**: Multi-MLP (Optimized) - Delta-E 0.48
380
+ - **Fastest**: Simple MLP (0.0008 ms/sample) - 1,654x faster than Colour library
381
+ - **Best Balance**: Multi-MLP (Optimized) - 154x faster with Delta-E 0.48
382
+
383
+ ### Model Architectures
384
+
385
+ 9 architectures were evaluated for the Munsell to xyY direction:
386
+
387
+ **Single-Stage Models**
388
+
389
+ 1. **Simple MLP** - Basic MLP network, 4 inputs to 3 outputs
390
+ 2. **Multi-Head** - Shared encoder with 3 independent decoder heads (x, y, Y)
391
+ 3. **Multi-Head (Optimized)** - Hyperparameter-optimized variant
392
+ 4. **Multi-MLP** - 3 completely independent MLP branches
393
+ 5. **Multi-MLP (Optimized)** - Hyperparameter-optimized variant (BEST)
394
+
395
+ **Two-Stage Models**
396
+
397
+ 6. **Multi-MLP + Error Predictor** - Base Multi-MLP with unified error correction
398
+ 7. **Multi-MLP + Multi-Error Predictor** - 3 independent error predictors
399
+ 8. **Multi-MLP (Opt) + Multi-Error Predictor (Opt)** - Optimized two-stage
400
+ 9. **Multi-Head + Multi-Error Predictor** - Multi-Head with error correction
401
+
402
+ The **Multi-MLP (Optimized)** architecture achieved the best results with Delta-E 0.48.
403
+
404
+ ### Differentiable Approximator
405
+
406
+ A small MLP (68K parameters) trained to approximate the Munsell to xyY conversion for use in differentiable Delta-E loss:
407
+
408
+ - **Architecture**: 4 -> 128 -> 256 -> 128 -> 3 with LayerNorm + SiLU
409
+ - **Accuracy**: MAE ~0.0006 for x, y, and Y components
410
+ - **Output formats**: PyTorch (.pth), ONNX, and JAX-compatible weights (.npz)
411
+
412
+ This enables differentiable Munsell to xyY conversion, which was previously only possible through non-differentiable lookup tables.
413
+
414
+ ## Shared Infrastructure
415
+
416
+ ### Hyperparameter Optimization
417
+
418
+ Optuna was used for systematic hyperparameter search over:
419
+
420
+ - Learning rate (1e-4 to 1e-3)
421
+ - Batch size (256, 512, 1024)
422
+ - Dropout rate (0.0 to 0.2)
423
+ - Chroma branch width multiplier (1.0 to 2.0)
424
+ - Loss function weights (MSE, Huber)
425
+
426
+ Key finding: **No dropout (0.0)** consistently performed better across all models in both conversion directions, contrary to typical deep learning recommendations for regularization.
427
+
428
+ ### Training Infrastructure
429
+
430
+ - **Optimizer**: AdamW with weight decay
431
+ - **Scheduler**: ReduceLROnPlateau (patience=10, factor=0.5)
432
+ - **Early stopping**: Patience=20 epochs
433
+ - **Checkpointing**: Best model saved based on validation loss
434
+ - **Logging**: MLflow for experiment tracking
435
+
436
+ ### JAX Delta-E Implementation
437
+
438
+ Located in `learning_munsell/losses/jax_delta_e.py`:
439
+
440
+ - Differentiable xyY -> XYZ -> Lab color space conversions
441
+ - Full CIE 2000 Delta-E implementation with gradient support
442
+ - JIT-compiled functions for performance
443
+
444
+ Usage:
445
+
446
+ ```python
447
+ from learning_munsell.losses import delta_E_loss, delta_E_CIE2000
448
+
449
+ # Compute perceptual loss between predicted and target xyY
450
+ loss = delta_E_loss(pred_xyY, target_xyY)
451
+ ```
452
+
453
+ ## Limitations
454
+
455
+ ### BatchNorm Instability on MPS
456
+
457
+ Models using `BatchNorm1d` layers exhibit numerical instability when trained on Apple Silicon GPUs via the MPS backend:
458
+
459
+ 1. **Validation loss spikes** during training
460
+ 2. **Occasional extreme outputs** during inference (e.g., 20M instead of ~0.1)
461
+ 3. **Non-reproducible behavior**
462
+
463
+ **Affected Models**: Large dataset error predictors using BatchNorm.
464
+
465
+ **Workarounds**:
466
+
467
+ 1. Use CPU for training
468
+ 2. Replace BatchNorm with LayerNorm
469
+ 3. Use smaller models (300K samples vs 2M)
470
+ 4. Skip error predictor stage for affected models
471
+
472
+ The recommended production model (`multi_resnet_error_predictor_large.onnx`) was trained on the large dataset and does not exhibit this instability.
473
+
474
+ **References**:
475
+
476
+ - [BatchNorm non-trainable exception](https://github.com/pytorch/pytorch/issues/98602)
477
+ - [ONNX export incorrect on MPS](https://github.com/pytorch/pytorch/issues/83230)
478
+ - [MPS kernel bugs](https://elanapearl.github.io/blog/2025/the-bug-that-taught-me-pytorch/)
learning_munsell/__init__.py ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ """Learning Munsell - Machine Learning for Munsell Color Conversions."""
2
+
3
+ from pathlib import Path
4
+
5
+ __all__ = ["PROJECT_ROOT"]
6
+
7
+ PROJECT_ROOT = Path(__file__).parent.parent
learning_munsell/analysis/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Analysis utilities for Munsell color conversion models."""
learning_munsell/analysis/error_analysis.py ADDED
@@ -0,0 +1,304 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Analyze error distribution to identify problematic regions in Munsell space.
3
+
4
+ This script:
5
+ 1. Runs the best model on all REAL Munsell colors
6
+ 2. Computes Delta-E for each sample
7
+ 3. Identifies samples with high error (Delta-E > threshold)
8
+ 4. Analyzes patterns: which hue families, value ranges, chroma ranges have issues
9
+ 5. Outputs statistics and visualizations
10
+ """
11
+
12
+ import logging
13
+ from collections import defaultdict
14
+
15
+ import numpy as np
16
+ import onnxruntime as ort
17
+ from colour import XYZ_to_Lab, xyY_to_XYZ
18
+ from colour.difference import delta_E_CIE2000
19
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
20
+ from colour.notation.munsell import (
21
+ CCS_ILLUMINANT_MUNSELL,
22
+ munsell_colour_to_munsell_specification,
23
+ munsell_specification_to_xyY,
24
+ )
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+
28
# Plain message-only logging; all report formatting is done in the strings.
logging.basicConfig(level=logging.INFO, format="%(message)s")
LOGGER = logging.getLogger(__name__)

# Munsell hue code -> hue family abbreviation, used for the per-family
# error breakdown. Code 0 is an alias of 10 ("RP") — presumably the hue
# code wraps around; confirm against the specification encoding.
HUE_NAMES = {
    1: "R",
    2: "YR",
    3: "Y",
    4: "GY",
    5: "G",
    6: "BG",
    7: "B",
    8: "PB",
    9: "P",
    10: "RP",
    0: "RP",
}
44
+
45
+
46
def load_model_and_params(model_name: str):
    """Load the ONNX inference session and its normalization parameters.

    Looks up ``<model_name>.onnx`` and the companion
    ``<model_name>_normalization_params.npz`` archive under
    ``models/from_xyY`` and returns the ONNX session together with the
    input and output normalization dictionaries stored in the archive.
    """
    base_dir = PROJECT_ROOT / "models" / "from_xyY"
    onnx_file = base_dir / f"{model_name}.onnx"
    npz_file = base_dir / f"{model_name}_normalization_params.npz"

    # Fail early with a clear message when either artifact is missing.
    if not onnx_file.exists():
        raise FileNotFoundError(f"Model not found: {onnx_file}")
    if not npz_file.exists():
        raise FileNotFoundError(f"Params not found: {npz_file}")

    runtime_session = ort.InferenceSession(str(onnx_file))
    archive = np.load(npz_file, allow_pickle=True)

    # The archive stores 0-d object arrays; ``.item()`` unwraps the dicts.
    return (
        runtime_session,
        archive["input_params"].item(),
        archive["output_params"].item(),
    )
64
+
65
+
66
def normalize_input(xyY: np.ndarray, params: dict) -> np.ndarray:
    """Min-max normalize an xyY array using the model's stored ranges.

    The Y (luminance) channel is first rescaled from 0-100 to 0-1, then
    every channel is mapped to [0, 1] using the ``*_range`` entries in
    ``params``. Returns a new float32 array; the input is not modified.
    """
    out = np.copy(xyY).astype(np.float32)

    x_lo, x_hi = params["x_range"]
    y_lo, y_hi = params["y_range"]
    Y_lo, Y_hi = params["Y_range"]

    out[..., 0] = (xyY[..., 0] - x_lo) / (x_hi - x_lo)
    out[..., 1] = (xyY[..., 1] - y_lo) / (y_hi - y_lo)
    # Y arrives in the 0-100 range; bring it to 0-1 before min-max scaling.
    out[..., 2] = (xyY[..., 2] / 100.0 - Y_lo) / (Y_hi - Y_lo)
    return out
81
+
82
+
83
def denormalize_output(pred: np.ndarray, params: dict) -> np.ndarray:
    """Map normalized network outputs back to Munsell specification units.

    Each of the four channels (hue, value, chroma, code) is rescaled from
    [0, 1] to its original range as recorded in ``params``. Returns a new
    array; the input is not modified.
    """
    channel_ranges = (
        params["hue_range"],
        params["value_range"],
        params["chroma_range"],
        params["code_range"],
    )
    denorm = np.copy(pred)
    for channel, (lo, hi) in enumerate(channel_ranges):
        denorm[..., channel] = pred[..., channel] * (hi - lo) + lo
    return denorm
103
+
104
+
105
def compute_delta_e(pred_spec: np.ndarray, gt_xyY: np.ndarray) -> float:
    """Return the CIE 2000 Delta-E between a predicted Munsell spec and a
    ground-truth xyY colour.

    The prediction is round-tripped spec -> xyY -> XYZ -> Lab; the ground
    truth goes xyY -> XYZ -> Lab (after rescaling Y from 0-100 to 0-1).
    Any failure in the colour-science round trip yields NaN so the caller
    can simply filter out invalid samples.
    """
    try:
        # Predicted specification -> Lab.
        lab_pred = XYZ_to_Lab(
            xyY_to_XYZ(munsell_specification_to_xyY(pred_spec)),
            CCS_ILLUMINANT_MUNSELL,
        )

        # Ground-truth xyY -> Lab; Y is stored as 0-100 and must be 0-1.
        reference = gt_xyY.copy()
        reference[2] = gt_xyY[2] / 100.0
        lab_ref = XYZ_to_Lab(xyY_to_XYZ(reference), CCS_ILLUMINANT_MUNSELL)

        return delta_E_CIE2000(lab_ref, lab_pred)
    except Exception:
        # Best-effort: signal failure with NaN instead of aborting the sweep.
        return np.nan
121
+
122
+
123
def analyze_errors(model_name: str = "multi_head_large", threshold: float = 3.0):
    """Analyze the Delta-E error distribution of a model on REAL Munsell colors.

    Runs the ONNX model over every entry of ``MUNSELL_COLOURS_REAL``,
    computes the per-sample CIE 2000 Delta-E, then logs aggregate
    statistics broken down by hue family, value range and chroma range,
    the 20 worst samples, and per-component errors for samples whose
    Delta-E exceeds ``threshold``.

    Parameters
    ----------
    model_name
        Base name of the ONNX model under ``models/from_xyY``.
    threshold
        Delta-E above which a sample counts as "high error".

    Returns
    -------
    list[dict]
        One record per successfully evaluated sample (spec, prediction,
        Delta-E and component values).
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Analysis for %s", model_name)
    LOGGER.info("=" * 80)

    # Load model
    session, input_params, output_params = load_model_and_params(model_name)
    input_name = session.get_inputs()[0].name

    # Collect data
    results = []

    for munsell_spec_tuple, xyY_gt in MUNSELL_COLOURS_REAL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"

        try:
            gt_spec = munsell_colour_to_munsell_specification(munsell_str)
            gt_xyY = np.array(xyY_gt)

            # Predict: model expects a normalized (1, 3) xyY batch.
            xyY_norm = normalize_input(gt_xyY.reshape(1, 3), input_params)
            pred_norm = session.run(None, {input_name: xyY_norm})[0]
            pred_spec = denormalize_output(pred_norm, output_params)[0]

            # Clamp to valid ranges. NOTE(review): the hue lower bound is
            # 0.5 here (vs 0.0 elsewhere in the project) — confirm intent.
            pred_spec[0] = np.clip(pred_spec[0], 0.5, 10.0)
            pred_spec[1] = np.clip(pred_spec[1], 1.0, 9.0)
            pred_spec[2] = np.clip(pred_spec[2], 0.0, 50.0)
            pred_spec[3] = np.clip(pred_spec[3], 1.0, 10.0)
            # Hue code is categorical — snap it to the nearest integer.
            pred_spec[3] = np.round(pred_spec[3])

            # Compute Delta-E (NaN when the colour round trip fails).
            delta_e = compute_delta_e(pred_spec, gt_xyY)

            if not np.isnan(delta_e):
                results.append({
                    "munsell_str": munsell_str,
                    "gt_spec": gt_spec,
                    "pred_spec": pred_spec,
                    "delta_e": delta_e,
                    "hue": gt_spec[0],
                    "value": gt_spec[1],
                    "chroma": gt_spec[2],
                    "code": int(gt_spec[3]),
                    "gt_xyY": gt_xyY,
                })
        except Exception as e:
            # Best-effort sweep: log and continue with the next colour.
            LOGGER.warning("Failed for %s: %s", munsell_str, e)

    LOGGER.info("\nTotal samples evaluated: %d", len(results))

    # Overall statistics
    delta_es = [r["delta_e"] for r in results]
    LOGGER.info("\nOverall Delta-E Statistics:")
    LOGGER.info(" Mean: %.4f", np.mean(delta_es))
    LOGGER.info(" Median: %.4f", np.median(delta_es))
    LOGGER.info(" Std: %.4f", np.std(delta_es))
    LOGGER.info(" Min: %.4f", np.min(delta_es))
    LOGGER.info(" Max: %.4f", np.max(delta_es))

    # Distribution: cumulative counts below fixed Delta-E thresholds.
    LOGGER.info("\nDelta-E Distribution:")
    for thresh in [1.0, 2.0, 3.0, 5.0, 10.0]:
        count = sum(1 for d in delta_es if d <= thresh)
        pct = 100 * count / len(delta_es)
        LOGGER.info(" <= %.1f: %4d (%.1f%%)", thresh, count, pct)

    # High error samples (reused for the component-error section below).
    high_error = [r for r in results if r["delta_e"] > threshold]
    LOGGER.info("\nSamples with Delta-E > %.1f: %d (%.1f%%)",
                threshold, len(high_error), 100 * len(high_error) / len(results))

    # Analyze by hue family
    LOGGER.info("\n" + "=" * 40)
    LOGGER.info("Analysis by Hue Family")
    LOGGER.info("=" * 40)

    by_hue = defaultdict(list)
    for r in results:
        # Unknown codes (not in HUE_NAMES) are labelled "?<code>".
        hue_name = HUE_NAMES.get(r["code"], f"?{r['code']}")
        by_hue[hue_name].append(r["delta_e"])

    LOGGER.info("\n%-4s %5s %6s %6s %6s %s",
                "Hue", "Count", "Mean", "Median", "Max", ">3.0")
    for hue_name in ["R", "YR", "Y", "GY", "G", "BG", "B", "PB", "P", "RP"]:
        if hue_name in by_hue:
            des = by_hue[hue_name]
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info("%-4s %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                        hue_name, len(des), np.mean(des), np.median(des),
                        np.max(des), high, 100*high/len(des))

    # Analyze by value range (half-open [min, max) buckets)
    LOGGER.info("\n" + "=" * 40)
    LOGGER.info("Analysis by Value Range")
    LOGGER.info("=" * 40)

    value_ranges = [(1, 3), (3, 5), (5, 7), (7, 9)]
    LOGGER.info("\n%-8s %5s %6s %6s %6s %s",
                "Value", "Count", "Mean", "Median", "Max", ">3.0")
    for v_min, v_max in value_ranges:
        des = [r["delta_e"] for r in results if v_min <= r["value"] < v_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info("[%d-%d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                        v_min, v_max, len(des), np.mean(des), np.median(des),
                        np.max(des), high, 100*high/len(des) if des else 0)

    # Analyze by chroma range (half-open [min, max) buckets)
    LOGGER.info("\n" + "=" * 40)
    LOGGER.info("Analysis by Chroma Range")
    LOGGER.info("=" * 40)

    chroma_ranges = [(0, 4), (4, 8), (8, 12), (12, 20), (20, 50)]
    LOGGER.info("\n%-8s %5s %6s %6s %6s %s",
                "Chroma", "Count", "Mean", "Median", "Max", ">3.0")
    for c_min, c_max in chroma_ranges:
        des = [r["delta_e"] for r in results if c_min <= r["chroma"] < c_max]
        if des:
            high = sum(1 for d in des if d > 3.0)
            LOGGER.info("[%2d-%2d) %5d %6.2f %6.2f %6.2f %d (%.0f%%)",
                        c_min, c_max, len(des), np.mean(des), np.median(des),
                        np.max(des), high, 100*high/len(des) if des else 0)

    # Top 20 worst samples, by descending Delta-E.
    LOGGER.info("\n" + "=" * 40)
    LOGGER.info("Top 20 Worst Samples")
    LOGGER.info("=" * 40)

    worst = sorted(results, key=lambda r: r["delta_e"], reverse=True)[:20]
    LOGGER.info("\n%-15s %6s %-20s %-20s",
                "Munsell", "DeltaE", "GT Spec", "Pred Spec")
    for r in worst:
        gt = f"[{r['gt_spec'][0]:.1f}, {r['gt_spec'][1]:.1f}, {r['gt_spec'][2]:.1f}, {int(r['gt_spec'][3])}]"
        pred = f"[{r['pred_spec'][0]:.1f}, {r['pred_spec'][1]:.1f}, {r['pred_spec'][2]:.1f}, {int(r['pred_spec'][3])}]"
        LOGGER.info("%-15s %6.2f %-20s %-20s",
                    r["munsell_str"], r["delta_e"], gt, pred)

    # Per-component absolute errors restricted to high-error samples, to
    # show which component (hue/value/chroma/code) drives the Delta-E.
    LOGGER.info("\n" + "=" * 40)
    LOGGER.info("Component Errors for High-Error Samples (Delta-E > %.1f)", threshold)
    LOGGER.info("=" * 40)

    if high_error:
        hue_errors = [abs(r["pred_spec"][0] - r["gt_spec"][0]) for r in high_error]
        value_errors = [abs(r["pred_spec"][1] - r["gt_spec"][1]) for r in high_error]
        chroma_errors = [abs(r["pred_spec"][2] - r["gt_spec"][2]) for r in high_error]
        code_errors = [abs(r["pred_spec"][3] - r["gt_spec"][3]) for r in high_error]

        LOGGER.info("\n%-10s %6s %6s %6s",
                    "Component", "Mean", "Median", "Max")
        LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Hue",
                    np.mean(hue_errors), np.median(hue_errors), np.max(hue_errors))
        LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Value",
                    np.mean(value_errors), np.median(value_errors), np.max(value_errors))
        LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Chroma",
                    np.mean(chroma_errors), np.median(chroma_errors), np.max(chroma_errors))
        LOGGER.info("%-10s %6.2f %6.2f %6.2f", "Code",
                    np.mean(code_errors), np.median(code_errors), np.max(code_errors))

    return results
286
+
287
+
288
def main():
    """Run the error analysis for each candidate model."""
    # Models to analyse; missing model files are skipped with a warning.
    candidates = [
        "multi_head_large",
    ]

    for name in candidates:
        try:
            analyze_errors(name, threshold=3.0)
        except FileNotFoundError as error:
            LOGGER.warning("Skipping %s: %s", name, error)
        LOGGER.info("\n")


if __name__ == "__main__":
    main()
learning_munsell/comparison/from_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Comparison scripts for xyY to Munsell conversion models."""
learning_munsell/comparison/from_xyY/compare_all_models.py ADDED
@@ -0,0 +1,1292 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare all ML models for xyY to Munsell conversion on real Munsell data.
3
+
4
+ Models to compare:
5
+ 1. MLP (Base only)
6
+ 2. MLP + Error Predictor (Two-stage)
7
+ 3. Unified MLP
8
+ 4. MLP + Self-Attention
9
+ 5. MLP + Self-Attention + Error Predictor
10
+ 6. Deep + Wide
11
+ 7. Mixture of Experts
12
+ 8. FT-Transformer
13
+ """
14
+
15
+ import logging
16
+ import time
17
+ import warnings
18
+ from datetime import datetime
19
+ from pathlib import Path
20
+ from typing import Any
21
+
22
+ import numpy as np
23
+ import onnxruntime as ort
24
+ from colour import XYZ_to_Lab, xyY_to_XYZ
25
+ from colour.difference import delta_E_CIE2000
26
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
27
+ from colour.notation.munsell import (
28
+ CCS_ILLUMINANT_MUNSELL,
29
+ munsell_colour_to_munsell_specification,
30
+ munsell_specification_to_xyY,
31
+ xyY_to_munsell_specification,
32
+ )
33
+ from numpy.typing import NDArray
34
+
35
+ from learning_munsell import PROJECT_ROOT
36
+ from learning_munsell.utilities.common import (
37
+ benchmark_inference_speed,
38
+ get_model_size_mb,
39
+ )
40
+
41
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
42
+ LOGGER = logging.getLogger(__name__)
43
+
44
+
45
+ def normalize_input(X: NDArray, params: dict[str, Any] | None) -> NDArray:
46
+ """Normalize xyY input.
47
+
48
+ If params is None, xyY is assumed to already be in [0, 1] range (no normalization needed).
49
+ """
50
+ if params is None:
51
+ # xyY is already in [0, 1] range - no normalization needed
52
+ return X.astype(np.float32)
53
+
54
+ X_norm = np.copy(X)
55
+ X_norm[..., 0] = (X[..., 0] - params["x_range"][0]) / (
56
+ params["x_range"][1] - params["x_range"][0]
57
+ )
58
+ X_norm[..., 1] = (X[..., 1] - params["y_range"][0]) / (
59
+ params["y_range"][1] - params["y_range"][0]
60
+ )
61
+ X_norm[..., 2] = (X[..., 2] - params["Y_range"][0]) / (
62
+ params["Y_range"][1] - params["Y_range"][0]
63
+ )
64
+ return X_norm.astype(np.float32)
65
+
66
+
67
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """Rescale normalized Munsell outputs back to specification units.

    Channels are (hue, value, chroma, code); each is mapped from [0, 1]
    to the range recorded in ``params``. Returns a new array; the input
    is not modified.
    """
    keys = ("hue_range", "value_range", "chroma_range", "code_range")
    restored = np.copy(y_norm)
    for axis, key in enumerate(keys):
        low, high = params[key]
        restored[..., axis] = y_norm[..., axis] * (high - low) + low
    return restored
87
+
88
+
89
def clamp_munsell_specification(specification: NDArray) -> NDArray:
    """Clamp a Munsell specification to the ranges the colour library accepts.

    Per-channel limits: hue [0, 10], value [1, 9] (colour library
    constraint), chroma [0, 50], code [1, 10]. Returns a new array; the
    input is not modified.
    """
    bounds = ((0.0, 10.0), (1.0, 9.0), (0.0, 50.0), (1.0, 10.0))
    clamped = np.copy(specification)
    for axis, (low, high) in enumerate(bounds):
        clamped[..., axis] = np.clip(specification[..., axis], low, high)
    return clamped
99
+
100
+
101
def evaluate_model(
    session: ort.InferenceSession,
    X_norm: NDArray,
    ground_truth: NDArray,
    params: dict[str, Any],
    input_name: str = "xyY",
    reference_Lab: NDArray | None = None,
) -> dict[str, Any]:
    """Evaluate a single model.

    Parameters
    ----------
    session
        ONNX inference session for the model under test.
    X_norm
        Normalized xyY inputs, shape (N, 3).
    ground_truth
        Ground-truth Munsell specifications, shape (N, 4)
        (hue, value, chroma, code).
    params
        Output normalization ranges used to denormalize predictions.
    input_name
        Name of the ONNX graph input to feed.
    reference_Lab
        Per-sample ground-truth Lab values; when given, a mean CIE 2000
        Delta-E is computed by round-tripping each prediction through
        the colour library, otherwise ``delta_E`` is NaN.

    Returns
    -------
    dict
        Per-component MAEs, per-sample error arrays, and ``delta_E``.
    """
    pred_norm = session.run(None, {input_name: X_norm})[0]
    pred = denormalize_output(pred_norm, params)
    # Per-sample, per-component absolute errors against ground truth.
    errors = np.abs(pred - ground_truth)

    result = {
        "hue_mae": np.mean(errors[:, 0]),
        "value_mae": np.mean(errors[:, 1]),
        "chroma_mae": np.mean(errors[:, 2]),
        "code_mae": np.mean(errors[:, 3]),
        "max_errors": np.max(errors, axis=1),
        "hue_errors": errors[:, 0],
        "value_errors": errors[:, 1],
        "chroma_errors": errors[:, 2],
        "code_errors": errors[:, 3],
    }

    # Compute Delta-E against ground truth
    if reference_Lab is not None:
        delta_E_values = []
        for idx in range(len(pred)):
            try:
                # Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab
                ml_spec = clamp_munsell_specification(pred[idx])

                # Round Code to nearest integer before round-trip conversion
                # (the code channel is categorical).
                ml_spec_for_conversion = ml_spec.copy()
                ml_spec_for_conversion[3] = round(ml_spec[3])

                ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
                ml_XYZ = xyY_to_XYZ(ml_xyy)
                ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)

                delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
                delta_E_values.append(delta_E)
            except (RuntimeError, ValueError):
                # Skip if conversion fails (e.g. spec outside the gamut the
                # colour library can convert); mean is over valid samples.
                continue

        result["delta_E"] = np.mean(delta_E_values) if delta_E_values else np.nan
    else:
        result["delta_E"] = np.nan

    return result
153
+
154
+
155
def _format_speed_multiplier(multiplier: float) -> str:
    """Format a speed multiplier with precision adapted to its magnitude."""
    if multiplier > 1000:
        return f"{multiplier:.0f}x"
    if multiplier > 100:
        return f"{multiplier:.1f}x"
    return f"{multiplier:.2f}x"


def generate_html_report(
    results: dict[str, dict[str, Any]],
    num_samples: int,
    output_file: Path,
    baseline_inference_time_ms: float | None = None,
) -> None:
    """Generate a self-contained, dark-themed HTML comparison report.

    Parameters
    ----------
    results
        Mapping of model name to its metrics. Each entry must provide
        ``hue_mae``, ``value_mae``, ``chroma_mae``, ``code_mae``,
        ``max_errors`` (per-sample array), ``model_size_mb``,
        ``inference_time_ms`` and ``delta_E`` (may be NaN when the
        round-trip conversion failed for every sample).
    num_samples
        Number of test samples, shown in the report header and notes.
    output_file
        Destination path for the generated HTML file.
    baseline_inference_time_ms
        Per-sample time of the baseline iterative implementation, used to
        compute "N times faster" multipliers. When ``None`` (or zero) the
        multiplier column shows a placeholder instead of raising.
    """
    # Average MAE across the four Munsell components, per model.
    avg_maes = {}
    for model_name, result in results.items():
        avg_maes[model_name] = np.mean(
            [
                result["hue_mae"],
                result["value_mae"],
                result["chroma_mae"],
                result["code_mae"],
            ]
        )

    # Rank models by average MAE (lowest first).
    sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])

    # Precision thresholds for the accuracy table.
    thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]

    html = f"""<!DOCTYPE html>
<html lang="en" class="dark">
<head>
    <meta charset="UTF-8">
    <meta name="viewport" content="width=device-width, initial-scale=1.0">
    <title>ML Model Comparison Report - {datetime.now().strftime("%Y-%m-%d %H:%M")}</title>
    <script src="https://cdn.tailwindcss.com"></script>
    <script>
        tailwind.config = {{
            darkMode: 'class',
            theme: {{
                extend: {{
                    colors: {{
                        border: "hsl(240 3.7% 15.9%)",
                        input: "hsl(240 3.7% 15.9%)",
                        ring: "hsl(240 4.9% 83.9%)",
                        background: "hsl(240 10% 3.9%)",
                        foreground: "hsl(0 0% 98%)",
                        primary: {{
                            DEFAULT: "hsl(263 70% 60%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        secondary: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        muted: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(240 5% 64.9%)",
                        }},
                        accent: {{
                            DEFAULT: "hsl(240 3.7% 15.9%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                        card: {{
                            DEFAULT: "hsl(240 10% 6%)",
                            foreground: "hsl(0 0% 98%)",
                        }},
                    }}
                }}
            }}
        }}
    </script>
    <style>
        .gradient-primary {{
            background: linear-gradient(135deg, hsl(263 70% 50%) 0%, hsl(280 70% 45%) 100%);
        }}
        .bar-fill {{
            background: linear-gradient(90deg, hsl(263 70% 60%) 0%, hsl(280 70% 55%) 100%);
            transition: width 0.5s cubic-bezier(0.4, 0, 0.2, 1);
        }}
    </style>
</head>
<body class="bg-background text-foreground antialiased">
    <div class="max-w-7xl mx-auto p-6 space-y-6">
        <!-- Header -->
        <div class="gradient-primary rounded-lg p-8 shadow-2xl border border-primary/20">
            <h1 class="text-4xl font-bold text-white mb-2">ML Model Comparison Report</h1>
            <div class="text-white/90 space-y-1">
                <p class="text-lg">xyY to Munsell Specification Conversion</p>
                <p class="text-sm">Generated: {datetime.now().strftime("%Y-%m-%d %H:%M:%S")}</p>
                <p class="text-sm">Test Samples: <span class="font-semibold">{num_samples:,}</span> real Munsell colors</p>
            </div>
        </div>
"""

    # Best models for each headline metric. Delta-E may be NaN for every
    # model, in which case there is no "best Delta-E" model at all.
    delta_E_values = [
        r["delta_E"] for r in results.values() if not np.isnan(r["delta_E"])
    ]

    best_delta_E = (
        min(
            results.items(),
            key=lambda x: x[1]["delta_E"]
            if not np.isnan(x[1]["delta_E"])
            else float("inf"),
        )[0]
        if delta_E_values
        else None
    )
    best_avg = sorted_models[0][0]

    best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
    best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]

    # Guard against the all-NaN Delta-E case so the summary card does not
    # index ``results[None]``.
    if best_delta_E is not None:
        best_delta_E_display = f"{results[best_delta_E]['delta_E']:.4f}"
        best_delta_E_name = best_delta_E
    else:
        best_delta_E_display = "—"
        best_delta_E_name = "N/A"

    html += f"""
        <!-- Best Models Summary -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
            <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
                    <div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
                    <div class="text-sm text-foreground/80">{best_size}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
                    <div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
                    <div class="text-sm text-foreground/80">{best_speed}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
                    <div class="text-3xl font-bold text-primary mb-3">{best_delta_E_display}</div>
                    <div class="text-sm text-foreground/80">{best_delta_E_name}</div>
                </div>
                <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
                    <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
                    <div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.4f}</div>
                    <div class="text-sm text-foreground/80">{best_avg}</div>
                </div>
            </div>
        </div>
"""

    # Baseline speed (Colour Library iterative). May be absent; treat a
    # missing/zero baseline as "no relative speed available" rather than
    # dividing by/with None.
    baseline_speed = (
        baseline_inference_time_ms
        if baseline_inference_time_ms is not None
        else 0.0
    )

    # Sort by Delta-E for the performance table (best first; NaN last).
    sorted_by_delta_E = sorted(
        results.items(),
        key=lambda x: x[1]["delta_E"]
        if not np.isnan(x[1]["delta_E"])
        else float("inf"),
    )

    # Find the single fastest model (highest multiplier) for highlighting.
    max_speed_multiplier = 0.0
    best_multiplier_model = None
    for model_name, result in results.items():
        speed_ms = result["inference_time_ms"]
        if speed_ms > 0 and baseline_speed > 0:
            speed_multiplier = baseline_speed / speed_ms
            if speed_multiplier > max_speed_multiplier:
                max_speed_multiplier = speed_multiplier
                best_multiplier_model = model_name

    html += """
        <!-- Performance Metrics Table -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
            <div class="overflow-x-auto">
                <table class="w-full text-sm">
                    <thead>
                        <tr class="border-b border-border">
                            <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Size (MB)
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">ONNX files</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Speed (ms/sample)
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">10 iterations</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                vs Baseline
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">Colour Iterative</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">
                                Delta-E
                                <div class="text-xs font-normal text-muted-foreground/70 mt-1">vs Colour Lib</div>
                            </th>
                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Average MAE</th>
                        </tr>
                    </thead>
                    <tbody>
"""

    for model_name, result in sorted_by_delta_E:
        size_mb = result["model_size_mb"]
        speed_ms = result["inference_time_ms"]
        avg_mae = avg_maes[model_name]
        delta_E = result["delta_E"]

        # Relative speed: how many times faster than the baseline. Zero
        # when no valid baseline or timing is available.
        speed_multiplier = (
            baseline_speed / speed_ms
            if (speed_ms > 0 and baseline_speed > 0)
            else 0
        )

        size_class = "text-primary font-semibold" if model_name == best_size else ""
        speed_class = "text-primary font-semibold" if model_name == best_speed else ""
        avg_class = "text-primary font-semibold" if model_name == best_avg else ""
        delta_E_class = (
            "text-primary font-semibold" if model_name == best_delta_E else ""
        )

        # Format Delta-E value (em-dash when the round-trip failed).
        delta_E_str = f"{delta_E:.4f}" if not np.isnan(delta_E) else "—"

        if baseline_speed <= 0:
            # No baseline available: no relative speed can be expressed.
            multiplier_class = "text-muted-foreground"
            multiplier_text = "—"
        elif abs(speed_multiplier - 1.0) < 0.01:
            # This row is (effectively) the baseline itself.
            multiplier_class = "text-muted-foreground"
            multiplier_text = "1.0x"
        elif model_name == best_multiplier_model:
            # Fastest model (highest multiplier) — highlight it.
            multiplier_class = "text-primary font-semibold"
            multiplier_text = _format_speed_multiplier(speed_multiplier)
        elif speed_multiplier > 1.0:
            # Faster than baseline but not the fastest.
            multiplier_class = ""
            multiplier_text = _format_speed_multiplier(speed_multiplier)
        else:
            # Slower than baseline.
            multiplier_class = "text-destructive"
            multiplier_text = f"{speed_multiplier:.2f}x"

        html += f"""
                        <tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
                            <td class="py-3 px-4 font-medium">{model_name}</td>
                            <td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
                            <td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
                            <td class="py-3 px-4 text-right {multiplier_class}">{multiplier_text}</td>
                            <td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
                            <td class="py-3 px-4 text-right {avg_class}">{avg_mae:.4f}</td>
                        </tr>
"""

    # NOTE: the footnote previously said "<1.0x are faster", contradicting
    # the multiplier definition (baseline / model; >1.0x means faster) and
    # the row styling above. It also hard-coded the sample count; use
    # ``num_samples`` so the note stays correct for any dataset size.
    html += f"""
                    </tbody>
                </table>
            </div>
            <div class="mt-6 p-4 bg-muted/30 rounded-md border border-primary/20">
                <div class="text-sm space-y-2">
                    <div><span class="text-primary font-semibold">Note:</span> Speed measured with 10 iterations (3 warmup + 10 benchmark) on {num_samples:,} samples.</div>
                    <div class="text-xs text-muted-foreground">Two-stage models include both base and error predictor. Highlighted values show best in each metric.</div>
                    <div class="text-xs text-muted-foreground">Baseline comparison: Speed multipliers show relative performance vs Colour Library's iterative xyY_to_munsell_specification(). Values &gt;1.0x are faster.</div>
                </div>
            </div>
        </div>
"""

    # Overall ranking by Delta-E.
    html += """
        <!-- Overall Ranking -->
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-4 pb-2 border-b border-primary/30">Overall Ranking (by Delta-E)</h2>
            <div class="space-y-1">
"""

    # Sort by Delta-E (best = lowest); models with NaN are excluded.
    sorted_by_delta_E_ranking = sorted(
        [
            (name, res["delta_E"])
            for name, res in results.items()
            if not np.isnan(res["delta_E"])
        ],
        key=lambda x: x[1],
    )

    max_delta_E = (
        max(delta_E for _, delta_E in sorted_by_delta_E_ranking)
        if sorted_by_delta_E_ranking
        else 1.0
    )
    for rank, (model_name, delta_E) in enumerate(sorted_by_delta_E_ranking, 1):
        # Bar width proportional to Delta-E (worst model fills the bar).
        width_pct = (delta_E / max_delta_E) * 100
        html += f"""
            <div class="flex items-center gap-3 p-2 rounded-md hover:bg-muted/50 transition-colors">
                <div class="flex-none w-80 text-sm font-medium">
                    <span class="text-muted-foreground">{rank}.</span> {model_name}
                </div>
                <div class="flex-1 h-6 bg-muted rounded-md overflow-hidden">
                    <div class="bar-fill h-full rounded-md" style="width: {width_pct}%"></div>
                </div>
                <div class="flex-none w-20 text-right font-bold text-primary">{delta_E:.4f}</div>
            </div>
"""

    html += """
            </div>
        </div>
"""

    # Precision Threshold Table.
    html += """
        <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
            <h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
            <p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
            <div class="overflow-x-auto">
                <table class="w-full text-sm">
                    <thead>
                        <tr class="border-b border-border">
                            <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
"""

    for threshold in thresholds:
        html += f'                            <th class="text-right py-3 px-4 font-semibold text-muted-foreground">&lt; {threshold:.0e}</th>\n'

    html += """
                        </tr>
                    </thead>
                    <tbody>
"""

    # Best/worst accuracy for each threshold column, used to decide
    # whether highlighting is meaningful for that column.
    best_accuracies = {}
    min_accuracies = {}
    for threshold in thresholds:
        accuracies = [
            np.mean(results[model_name]["max_errors"] < threshold) * 100
            for model_name, _ in sorted_models
        ]
        best_accuracies[threshold] = max(accuracies)
        min_accuracies[threshold] = min(accuracies)

    for model_name, _ in sorted_models:
        result = results[model_name]
        row_class = (
            "bg-primary/10 border-l-2 border-l-primary"
            if model_name == best_avg
            else ""
        )
        html += f"""
                        <tr class="border-b border-border hover:bg-muted/30 transition-colors {row_class}">
                            <td class="text-left py-3 px-4 font-medium">{model_name}</td>
"""
        for threshold in thresholds:
            accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
            # Only highlight if there's meaningful variation
            # (>0.1% difference between best and worst).
            has_variation = (
                best_accuracies[threshold] - min_accuracies[threshold]
            ) > 0.1
            is_best = abs(accuracy_pct - best_accuracies[threshold]) < 0.01
            cell_class = (
                "text-right py-3 px-4 font-bold text-primary"
                if (has_variation and is_best)
                else "text-right py-3 px-4"
            )
            html += f'                            <td class="{cell_class}">{accuracy_pct:.2f}%</td>\n'

        html += """
                        </tr>
"""

    html += """
                    </tbody>
                </table>
            </div>
        </div>

    </div>
</body>
</html>
"""

    # Write HTML file.
    with open(output_file, "w") as f:
        f.write(html)

    LOGGER.info("")
    LOGGER.info("HTML report saved to: %s", output_file)
546
+
547
+ def main() -> None:
548
+ """Compare all models."""
549
+ LOGGER.info("=" * 80)
550
+ LOGGER.info("Comprehensive Model Comparison")
551
+ LOGGER.info("=" * 80)
552
+
553
+ # Paths
554
+ model_directory = PROJECT_ROOT / "models" / "from_xyY"
555
+
556
+ # Load real Munsell dataset
557
+ LOGGER.info("")
558
+ LOGGER.info("Loading real Munsell dataset...")
559
+ xyY_samples = []
560
+ ground_truth = []
561
+
562
+ for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
563
+ try:
564
+ hue_code, value, chroma = munsell_spec_tuple
565
+ munsell_str = f"{hue_code} {value}/{chroma}"
566
+ spec = munsell_colour_to_munsell_specification(munsell_str)
567
+ xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
568
+ xyY_samples.append(xyY_scaled)
569
+ ground_truth.append(spec)
570
+ except Exception: # noqa: BLE001, S112
571
+ continue
572
+
573
+ xyY_samples = np.array(xyY_samples)
574
+ ground_truth = np.array(ground_truth)
575
+ LOGGER.info("Loaded %d valid Munsell colors", len(xyY_samples))
576
+
577
+ # Define models to compare
578
+ models = [
579
+ {
580
+ "name": "MLP (Base Only)",
581
+ "files": [model_directory / "mlp.onnx"],
582
+ "params_file": model_directory / "mlp_normalization_params.npz",
583
+ "type": "single",
584
+ },
585
+ {
586
+ "name": "MLP + Error Predictor",
587
+ "files": [
588
+ model_directory / "mlp.onnx",
589
+ model_directory / "mlp_error_predictor.onnx",
590
+ ],
591
+ "params_file": model_directory / "mlp_normalization_params.npz",
592
+ "type": "two_stage",
593
+ },
594
+ {
595
+ "name": "Unified MLP",
596
+ "files": [model_directory / "unified_mlp.onnx"],
597
+ "params_file": model_directory / "unified_mlp_normalization_params.npz",
598
+ "type": "single",
599
+ },
600
+ {
601
+ "name": "MLP + Self-Attention",
602
+ "files": [model_directory / "mlp_attention.onnx"],
603
+ "params_file": model_directory
604
+ / "mlp_attention_normalization_params.npz",
605
+ "type": "single",
606
+ },
607
+ {
608
+ "name": "MLP + Self-Attention + Error Predictor",
609
+ "files": [
610
+ model_directory / "mlp_attention.onnx",
611
+ model_directory / "mlp_attention_error_predictor.onnx",
612
+ ],
613
+ "params_file": model_directory
614
+ / "mlp_attention_normalization_params.npz",
615
+ "type": "two_stage",
616
+ },
617
+ {
618
+ "name": "Deep + Wide",
619
+ "files": [model_directory / "deep_wide.onnx"],
620
+ "params_file": model_directory / "deep_wide_normalization_params.npz",
621
+ "type": "single",
622
+ },
623
+ {
624
+ "name": "Mixture of Experts",
625
+ "files": [model_directory / "mixture_of_experts.onnx"],
626
+ "params_file": model_directory
627
+ / "mixture_of_experts_normalization_params.npz",
628
+ "type": "single",
629
+ },
630
+ {
631
+ "name": "FT-Transformer",
632
+ "files": [model_directory / "ft_transformer.onnx"],
633
+ "params_file": model_directory / "ft_transformer_normalization_params.npz",
634
+ "type": "single",
635
+ },
636
+ {
637
+ "name": "Multi-Head",
638
+ "files": [model_directory / "multi_head.onnx"],
639
+ "params_file": model_directory / "multi_head_normalization_params.npz",
640
+ "type": "single",
641
+ },
642
+ {
643
+ "name": "Multi-Head (Optimized)",
644
+ "files": [model_directory / "multi_head_optimized.onnx"],
645
+ "params_file": model_directory / "multi_head_optimized_normalization_params.npz",
646
+ "type": "single",
647
+ },
648
+ {
649
+ "name": "Multi-Head + Error Predictor",
650
+ "files": [
651
+ model_directory / "multi_head.onnx",
652
+ model_directory / "multi_head_error_predictor.onnx",
653
+ ],
654
+ "params_file": model_directory / "multi_head_normalization_params.npz",
655
+ "type": "two_stage",
656
+ },
657
+ {
658
+ "name": "Multi-MLP",
659
+ "files": [model_directory / "multi_mlp.onnx"],
660
+ "params_file": model_directory / "multi_mlp_normalization_params.npz",
661
+ "type": "single",
662
+ },
663
+ {
664
+ "name": "Multi-MLP + Error Predictor",
665
+ "files": [
666
+ model_directory / "multi_mlp.onnx",
667
+ model_directory / "multi_mlp_error_predictor.onnx",
668
+ ],
669
+ "params_file": model_directory / "multi_mlp_normalization_params.npz",
670
+ "type": "two_stage",
671
+ },
672
+ {
673
+ "name": "Multi-MLP + Multi-Error Predictor",
674
+ "files": [
675
+ model_directory / "multi_mlp.onnx",
676
+ model_directory / "multi_mlp_multi_error_predictor.onnx",
677
+ ],
678
+ "params_file": model_directory / "multi_mlp_normalization_params.npz",
679
+ "type": "two_stage",
680
+ },
681
+ {
682
+ "name": "Multi-MLP + Multi-Error Predictor (Optimized)",
683
+ "files": [
684
+ model_directory / "multi_mlp.onnx",
685
+ model_directory / "multi_mlp_multi_error_predictor_optimized.onnx",
686
+ ],
687
+ "params_file": model_directory / "multi_mlp_normalization_params.npz",
688
+ "type": "two_stage",
689
+ },
690
+ {
691
+ "name": "Multi-MLP (Optimized)",
692
+ "files": [model_directory / "multi_mlp_optimized.onnx"],
693
+ "params_file": model_directory / "multi_mlp_optimized_normalization_params.npz",
694
+ "type": "single",
695
+ },
696
+ {
697
+ "name": "Multi-Head + Multi-Error Predictor",
698
+ "files": [
699
+ model_directory / "multi_head.onnx",
700
+ model_directory / "multi_head_multi_error_predictor.onnx",
701
+ ],
702
+ "params_file": model_directory / "multi_head_normalization_params.npz",
703
+ "type": "two_stage",
704
+ },
705
+ {
706
+ "name": "Multi-Head + Cross-Attention Error Predictor",
707
+ "files": [
708
+ model_directory / "multi_head.onnx",
709
+ model_directory / "multi_head_cross_attention_error_predictor.onnx",
710
+ ],
711
+ "params_file": model_directory / "multi_head_normalization_params.npz",
712
+ "type": "two_stage",
713
+ },
714
+ {
715
+ "name": "Multi-Head (Optimized) + Multi-Error Predictor (Optimized)",
716
+ "files": [
717
+ model_directory / "multi_head_optimized.onnx",
718
+ model_directory / "multi_head_error_predictor_optimized.onnx",
719
+ ],
720
+ "params_file": model_directory / "multi_head_optimized_normalization_params.npz",
721
+ "type": "two_stage",
722
+ },
723
+ {
724
+ "name": "Multi-Head (Circular Loss)",
725
+ "files": [model_directory / "multi_head_circular.onnx"],
726
+ "params_file": model_directory / "multi_head_circular_normalization_params.npz",
727
+ "type": "single",
728
+ },
729
+ {
730
+ "name": "Multi-Head (Large Dataset)",
731
+ "files": [model_directory / "multi_head_large.onnx"],
732
+ "params_file": model_directory / "multi_head_large_normalization_params.npz",
733
+ "type": "single",
734
+ },
735
+ {
736
+ "name": "Multi-Head + Multi-Error Predictor (Large Dataset)",
737
+ "files": [
738
+ model_directory / "multi_head_large.onnx",
739
+ model_directory / "multi_head_multi_error_predictor_large.onnx",
740
+ ],
741
+ "params_file": model_directory / "multi_head_large_normalization_params.npz",
742
+ "type": "two_stage",
743
+ },
744
+ {
745
+ "name": "Multi-MLP (Large Dataset)",
746
+ "files": [model_directory / "multi_mlp_large.onnx"],
747
+ "params_file": model_directory / "multi_mlp_large_normalization_params.npz",
748
+ "type": "single",
749
+ },
750
+ {
751
+ "name": "Multi-MLP + Multi-Error Predictor (Large Dataset)",
752
+ "files": [
753
+ model_directory / "multi_mlp_large.onnx",
754
+ model_directory / "multi_mlp_multi_error_predictor_large.onnx",
755
+ ],
756
+ "params_file": model_directory / "multi_mlp_large_normalization_params.npz",
757
+ "type": "two_stage",
758
+ },
759
+ {
760
+ "name": "Transformer (Large Dataset)",
761
+ "files": [model_directory / "transformer_large.onnx"],
762
+ "params_file": model_directory / "transformer_large_normalization_params.npz",
763
+ "type": "single",
764
+ },
765
+ {
766
+ "name": "Transformer + Error Predictor (Large Dataset)",
767
+ "files": [
768
+ model_directory / "transformer_large.onnx",
769
+ model_directory / "transformer_multi_error_predictor_large.onnx",
770
+ ],
771
+ "params_file": model_directory / "transformer_large_normalization_params.npz",
772
+ "type": "two_stage",
773
+ },
774
+ {
775
+ "name": "Multi-Head Refined (REAL Only)",
776
+ "files": [model_directory / "multi_head_refined_real.onnx"],
777
+ "params_file": model_directory / "multi_head_refined_real_normalization_params.npz",
778
+ "type": "single",
779
+ },
780
+ {
781
+ "name": "Multi-Head Refined + Error Predictor (REAL Only)",
782
+ "files": [
783
+ model_directory / "multi_head_refined_real.onnx",
784
+ model_directory / "multi_head_multi_error_predictor_refined_real.onnx",
785
+ ],
786
+ "params_file": model_directory / "multi_head_refined_real_normalization_params.npz",
787
+ "type": "two_stage",
788
+ },
789
+ {
790
+ "name": "Multi-Head + Multi-Error Predictor + Multi-Error Predictor (3-Stage)",
791
+ "files": [
792
+ model_directory / "multi_head_large.onnx",
793
+ model_directory / "multi_head_multi_error_predictor_large.onnx",
794
+ model_directory / "multi_head_3stage_error_predictor.onnx",
795
+ ],
796
+ "params_file": model_directory / "multi_head_large_normalization_params.npz",
797
+ "type": "three_stage",
798
+ },
799
+ {
800
+ "name": "Multi-Head (Weighted + Boundary Loss)",
801
+ "files": [model_directory / "multi_head_weighted_boundary.onnx"],
802
+ "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
803
+ "type": "single",
804
+ },
805
+ {
806
+ "name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor",
807
+ "files": [
808
+ model_directory / "multi_head_weighted_boundary.onnx",
809
+ model_directory / "multi_head_weighted_boundary_multi_error_predictor.onnx",
810
+ ],
811
+ "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
812
+ "type": "two_stage",
813
+ },
814
+ {
815
+ "name": "Multi-Head (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss)",
816
+ "files": [
817
+ model_directory / "multi_head_weighted_boundary.onnx",
818
+ model_directory / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx",
819
+ ],
820
+ "params_file": model_directory / "multi_head_weighted_boundary_normalization_params.npz",
821
+ "type": "two_stage",
822
+ },
823
+ {
824
+ "name": "Multi-MLP (Weighted + Boundary Loss) (Large Dataset)",
825
+ "files": [model_directory / "multi_mlp_weighted_boundary.onnx"],
826
+ "params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz",
827
+ "type": "single",
828
+ },
829
+ {
830
+ "name": "Multi-MLP (Weighted + Boundary Loss) + Multi-Error Predictor (Weighted + Boundary Loss) (Large Dataset)",
831
+ "files": [
832
+ model_directory / "multi_mlp_weighted_boundary.onnx",
833
+ model_directory / "multi_mlp_weighted_boundary_multi_error_predictor.onnx",
834
+ ],
835
+ "params_file": model_directory / "multi_mlp_weighted_boundary_normalization_params.npz",
836
+ "type": "two_stage",
837
+ },
838
+ {
839
+ "name": "Multi-ResNet (Large Dataset)",
840
+ "files": [model_directory / "multi_resnet_large.onnx"],
841
+ "params_file": model_directory / "multi_resnet_large_normalization_params.npz",
842
+ "type": "single",
843
+ },
844
+ {
845
+ "name": "Multi-ResNet + Multi-Error Predictor (Large Dataset)",
846
+ "files": [
847
+ model_directory / "multi_resnet_large.onnx",
848
+ model_directory / "multi_resnet_error_predictor_large.onnx",
849
+ ],
850
+ "params_file": model_directory / "multi_resnet_large_normalization_params.npz",
851
+ "type": "two_stage",
852
+ },
853
+ ]
854
+
855
+ # Benchmark colour library's iterative implementation first
856
+ LOGGER.info("")
857
+ LOGGER.info("=" * 80)
858
+ LOGGER.info("Colour Library (Iterative)")
859
+ LOGGER.info("=" * 80)
860
+
861
+ # Benchmark the iterative xyY_to_munsell_specification function
862
+ # Note: Using full dataset (100% of samples)
863
+
864
+ # Set random seed for reproducibility
865
+ np.random.seed(42)
866
+
867
+ # Use 100% of samples for comprehensive benchmarking
868
+ sample_count = len(xyY_samples)
869
+ sampled_indices = np.arange(len(xyY_samples))
870
+ xyY_benchmark_samples = xyY_samples[sampled_indices]
871
+
872
+ # Measure inference time on sampled Munsell colors
873
+ start_time = time.perf_counter()
874
+ convergence_failures = 0
875
+ successful_inferences = 0
876
+
877
+ with warnings.catch_warnings():
878
+ warnings.simplefilter("ignore")
879
+ for xyy in xyY_benchmark_samples:
880
+ try:
881
+ xyY_to_munsell_specification(xyy)
882
+ successful_inferences += 1
883
+ except (RuntimeError, ValueError):
884
+ # Out-of-gamut color that doesn't converge or not in renotation system
885
+ convergence_failures += 1
886
+
887
+ end_time = time.perf_counter()
888
+
889
+ # Calculate average time per successful inference (in milliseconds)
890
+ total_time_s = end_time - start_time
891
+ colour_inference_time_ms = (
892
+ (total_time_s / successful_inferences) * 1000
893
+ if successful_inferences > 0
894
+ else 0
895
+ )
896
+
897
+ LOGGER.info("")
898
+ LOGGER.info("Performance Metrics:")
899
+ LOGGER.info(" Successful inferences: %d", successful_inferences)
900
+ LOGGER.info(" Convergence failures: %d", convergence_failures)
901
+ LOGGER.info(" Inference Speed: %.4f ms/sample", colour_inference_time_ms)
902
+ LOGGER.info(" Note: This is the baseline iterative implementation")
903
+
904
+ # Store the baseline speed
905
+ baseline_inference_time_ms = colour_inference_time_ms
906
+
907
+ # Convert ground truth Munsell specs to CIE Lab for Delta-E comparison
908
+ # Path: Munsell spec → xyY → XYZ → Lab
909
+ LOGGER.info("")
910
+ LOGGER.info(
911
+ "Converting ground truth to CIE Lab for Delta-E comparison..."
912
+ )
913
+ LOGGER.info(" Path: Munsell spec \u2192 xyY \u2192 XYZ \u2192 Lab")
914
+ reference_Lab = []
915
+ for spec in ground_truth:
916
+ try:
917
+ # Munsell specification → xyY
918
+ xyy = munsell_specification_to_xyY(spec)
919
+ # xyY → XYZ
920
+ XYZ = xyY_to_XYZ(xyy)
921
+ # XYZ → Lab (Illuminant C for Munsell)
922
+ Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
923
+ reference_Lab.append(Lab)
924
+ except (RuntimeError, ValueError):
925
+ # If conversion fails, use NaN
926
+ reference_Lab.append(np.array([np.nan, np.nan, np.nan]))
927
+
928
+ reference_Lab = np.array(reference_Lab)
929
+ LOGGER.info(
930
+ " Converted %d ground truth specs to CIE Lab",
931
+ len(reference_Lab),
932
+ )
933
+
934
+ # Use the same sampled subset for ML model evaluations (for fair comparison)
935
+ xyY_samples = xyY_benchmark_samples
936
+ ground_truth = ground_truth[sampled_indices]
937
+
938
+ # Evaluate each model
939
+ results = {}
940
+
941
+ for model_info in models:
942
+ model_name = model_info["name"]
943
+ LOGGER.info("")
944
+ LOGGER.info("=" * 80)
945
+ LOGGER.info(model_name)
946
+ LOGGER.info("=" * 80)
947
+
948
+ # Load normalization params for this model
949
+ params = np.load(model_info["params_file"], allow_pickle=True)
950
+ # input_params may not exist if xyY is already in [0, 1] range
951
+ input_params = (
952
+ params["input_params"].item()
953
+ if "input_params" in params.files
954
+ else None
955
+ )
956
+ output_params = params["output_params"].item()
957
+
958
+ # Normalize input with this model's params (None means no normalization)
959
+ X_norm = normalize_input(xyY_samples, input_params)
960
+
961
+ # Calculate model size
962
+ model_size_mb = get_model_size_mb(model_info["files"])
963
+
964
+ if model_info["type"] == "two_stage":
965
+ # Two-stage model
966
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
967
+ error_session = ort.InferenceSession(str(model_info["files"][1]))
968
+
969
+ # Define inference callable for benchmarking
970
+ def two_stage_inference(
971
+ _base_session: ort.InferenceSession = base_session,
972
+ _error_session: ort.InferenceSession = error_session,
973
+ _X_norm: NDArray = X_norm,
974
+ ) -> NDArray:
975
+ base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
976
+ combined = np.concatenate([_X_norm, base_pred], axis=1).astype(
977
+ np.float32
978
+ )
979
+ error_corr = _error_session.run(None, {"combined_input": combined})[
980
+ 0
981
+ ]
982
+ return base_pred + error_corr
983
+
984
+ # Benchmark speed
985
+ inference_time_ms = benchmark_inference_speed(
986
+ two_stage_inference, X_norm
987
+ )
988
+
989
+ # Get predictions
990
+ base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
991
+ combined_input = np.concatenate(
992
+ [X_norm, base_pred_norm], axis=1
993
+ ).astype(np.float32)
994
+ error_correction_norm = error_session.run(
995
+ None, {"combined_input": combined_input}
996
+ )[0]
997
+ final_pred_norm = base_pred_norm + error_correction_norm
998
+ pred = denormalize_output(final_pred_norm, output_params)
999
+ errors = np.abs(pred - ground_truth)
1000
+
1001
+ result = {
1002
+ "hue_mae": np.mean(errors[:, 0]),
1003
+ "value_mae": np.mean(errors[:, 1]),
1004
+ "chroma_mae": np.mean(errors[:, 2]),
1005
+ "code_mae": np.mean(errors[:, 3]),
1006
+ "max_errors": np.max(errors, axis=1),
1007
+ "hue_errors": errors[:, 0],
1008
+ "value_errors": errors[:, 1],
1009
+ "chroma_errors": errors[:, 2],
1010
+ "code_errors": errors[:, 3],
1011
+ "model_size_mb": model_size_mb,
1012
+ "inference_time_ms": inference_time_ms,
1013
+ }
1014
+
1015
+ # Compute Delta-E against ground truth
1016
+ delta_E_values = []
1017
+ for idx in range(len(pred)):
1018
+ try:
1019
+ # Convert ML prediction to Lab: Munsell spec → xyY → XYZ → Lab
1020
+ ml_spec = clamp_munsell_specification(pred[idx])
1021
+
1022
+ # Round Code to nearest integer before round-trip conversion
1023
+ ml_spec_for_conversion = ml_spec.copy()
1024
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1025
+
1026
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1027
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1028
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1029
+
1030
+ # Get ground truth Lab
1031
+ reference_Lab_sample = reference_Lab[idx]
1032
+
1033
+ # Compute Delta-E CIE2000
1034
+ delta_E = delta_E_CIE2000(reference_Lab_sample, ml_Lab)
1035
+ delta_E_values.append(delta_E)
1036
+ except (RuntimeError, ValueError):
1037
+ # Skip if conversion fails
1038
+ continue
1039
+
1040
+ result["delta_E"] = (
1041
+ np.mean(delta_E_values) if delta_E_values else np.nan
1042
+ )
1043
+ elif model_info["type"] == "three_stage":
1044
+ # Three-stage model: base + error predictor 1 + error predictor 2
1045
+ base_session = ort.InferenceSession(str(model_info["files"][0]))
1046
+ error1_session = ort.InferenceSession(str(model_info["files"][1]))
1047
+ error2_session = ort.InferenceSession(str(model_info["files"][2]))
1048
+
1049
+ # Define inference callable for benchmarking
1050
+ def three_stage_inference(
1051
+ _base_session: ort.InferenceSession = base_session,
1052
+ _error1_session: ort.InferenceSession = error1_session,
1053
+ _error2_session: ort.InferenceSession = error2_session,
1054
+ _X_norm: NDArray = X_norm,
1055
+ ) -> NDArray:
1056
+ # Stage 1: Base model
1057
+ base_pred = _base_session.run(None, {"xyY": _X_norm})[0]
1058
+ # Stage 2: First error correction
1059
+ combined1 = np.concatenate([_X_norm, base_pred], axis=1).astype(
1060
+ np.float32
1061
+ )
1062
+ error1_corr = _error1_session.run(
1063
+ None, {"combined_input": combined1}
1064
+ )[0]
1065
+ stage2_pred = base_pred + error1_corr
1066
+ # Stage 3: Second error correction
1067
+ combined2 = np.concatenate([_X_norm, stage2_pred], axis=1).astype(
1068
+ np.float32
1069
+ )
1070
+ error2_corr = _error2_session.run(
1071
+ None, {"combined_input": combined2}
1072
+ )[0]
1073
+ return stage2_pred + error2_corr
1074
+
1075
+ # Benchmark speed
1076
+ inference_time_ms = benchmark_inference_speed(
1077
+ three_stage_inference, X_norm
1078
+ )
1079
+
1080
+ # Get predictions
1081
+ base_pred_norm = base_session.run(None, {"xyY": X_norm})[0]
1082
+ combined1 = np.concatenate([X_norm, base_pred_norm], axis=1).astype(
1083
+ np.float32
1084
+ )
1085
+ error1_corr_norm = error1_session.run(
1086
+ None, {"combined_input": combined1}
1087
+ )[0]
1088
+ stage2_pred_norm = base_pred_norm + error1_corr_norm
1089
+ combined2 = np.concatenate([X_norm, stage2_pred_norm], axis=1).astype(
1090
+ np.float32
1091
+ )
1092
+ error2_corr_norm = error2_session.run(
1093
+ None, {"combined_input": combined2}
1094
+ )[0]
1095
+ final_pred_norm = stage2_pred_norm + error2_corr_norm
1096
+ pred = denormalize_output(final_pred_norm, output_params)
1097
+ errors = np.abs(pred - ground_truth)
1098
+
1099
+ result = {
1100
+ "hue_mae": np.mean(errors[:, 0]),
1101
+ "value_mae": np.mean(errors[:, 1]),
1102
+ "chroma_mae": np.mean(errors[:, 2]),
1103
+ "code_mae": np.mean(errors[:, 3]),
1104
+ "max_errors": np.max(errors, axis=1),
1105
+ "hue_errors": errors[:, 0],
1106
+ "value_errors": errors[:, 1],
1107
+ "chroma_errors": errors[:, 2],
1108
+ "code_errors": errors[:, 3],
1109
+ "model_size_mb": model_size_mb,
1110
+ "inference_time_ms": inference_time_ms,
1111
+ }
1112
+
1113
+ # Compute Delta-E against ground truth for three-stage model
1114
+ delta_E_values = []
1115
+ for idx in range(len(pred)):
1116
+ try:
1117
+ ml_spec = clamp_munsell_specification(pred[idx])
1118
+ ml_spec_for_conversion = ml_spec.copy()
1119
+ ml_spec_for_conversion[3] = round(ml_spec[3])
1120
+ ml_xyy = munsell_specification_to_xyY(ml_spec_for_conversion)
1121
+ ml_XYZ = xyY_to_XYZ(ml_xyy)
1122
+ ml_Lab = XYZ_to_Lab(ml_XYZ, CCS_ILLUMINANT_MUNSELL)
1123
+ delta_E = delta_E_CIE2000(reference_Lab[idx], ml_Lab)
1124
+ delta_E_values.append(delta_E)
1125
+ except (RuntimeError, ValueError):
1126
+ continue
1127
+
1128
+ result["delta_E"] = (
1129
+ np.mean(delta_E_values) if delta_E_values else np.nan
1130
+ )
1131
+ else:
1132
+ # Single model
1133
+ session = ort.InferenceSession(str(model_info["files"][0]))
1134
+
1135
+ # Define inference callable for benchmarking
1136
+ def single_inference(
1137
+ _session: ort.InferenceSession = session, _X_norm: NDArray = X_norm
1138
+ ) -> NDArray:
1139
+ return _session.run(None, {"xyY": _X_norm})[0]
1140
+
1141
+ # Benchmark speed
1142
+ inference_time_ms = benchmark_inference_speed(single_inference, X_norm)
1143
+
1144
+ result = evaluate_model(
1145
+ session,
1146
+ X_norm,
1147
+ ground_truth,
1148
+ output_params,
1149
+ reference_Lab=reference_Lab,
1150
+ )
1151
+ result["model_size_mb"] = model_size_mb
1152
+ result["inference_time_ms"] = inference_time_ms
1153
+
1154
+ results[model_name] = result
1155
+
1156
+ # Print results
1157
+ LOGGER.info("")
1158
+ LOGGER.info("Mean Absolute Errors:")
1159
+ LOGGER.info(" Hue: %.4f", result["hue_mae"])
1160
+ LOGGER.info(" Value: %.4f", result["value_mae"])
1161
+ LOGGER.info(" Chroma: %.4f", result["chroma_mae"])
1162
+ LOGGER.info(" Code: %.4f", result["code_mae"])
1163
+ if not np.isnan(result["delta_E"]):
1164
+ LOGGER.info(" Delta-E (vs Ground Truth): %.4f", result["delta_E"])
1165
+ LOGGER.info("")
1166
+ LOGGER.info("Performance Metrics:")
1167
+ LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
1168
+ LOGGER.info(
1169
+ " Inference Speed: %.4f ms/sample", result["inference_time_ms"]
1170
+ )
1171
+
1172
+
1173
+ # Summary comparison
1174
+ LOGGER.info("")
1175
+ LOGGER.info("=" * 80)
1176
+ LOGGER.info("SUMMARY COMPARISON")
1177
+ LOGGER.info("=" * 80)
1178
+ LOGGER.info("")
1179
+
1180
+ if not results:
1181
+ LOGGER.info("⚠️ No models were successfully evaluated")
1182
+ return
1183
+
1184
+ # MAE comparison table
1185
+ LOGGER.info("Mean Absolute Error Comparison:")
1186
+ LOGGER.info("")
1187
+ header = "{:<35} {:>8} {:>8} {:>8} {:>8} {:>10}".format(
1188
+ "Model",
1189
+ "Hue",
1190
+ "Value",
1191
+ "Chroma",
1192
+ "Code",
1193
+ "Delta-E",
1194
+ )
1195
+ LOGGER.info(header)
1196
+ LOGGER.info("-" * 90)
1197
+
1198
+ for model_name, result in results.items():
1199
+ delta_E_str = (
1200
+ f"{result['delta_E']:.4f}" if not np.isnan(result["delta_E"]) else "N/A"
1201
+ )
1202
+ LOGGER.info(
1203
+ "%-35s %8.4f %8.4f %8.4f %8.4f %10s",
1204
+ model_name[:35],
1205
+ result["hue_mae"],
1206
+ result["value_mae"],
1207
+ result["chroma_mae"],
1208
+ result["code_mae"],
1209
+ delta_E_str,
1210
+ )
1211
+
1212
+ # Precision threshold comparison
1213
+ LOGGER.info("")
1214
+ LOGGER.info("Accuracy at Precision Thresholds:")
1215
+ LOGGER.info("")
1216
+
1217
+ thresholds = [1e-4, 1e-3, 1e-2, 1e-1, 1.0]
1218
+ header_parts = [f"{'Model/Threshold':<35}"]
1219
+ header_parts.extend(f"{f'< {threshold:.0e}':>10}" for threshold in thresholds)
1220
+ LOGGER.info(" ".join(header_parts))
1221
+ LOGGER.info("-" * 80)
1222
+
1223
+ for model_name, result in results.items():
1224
+ row_parts = [f"{model_name[:35]:<35}"]
1225
+ for threshold in thresholds:
1226
+ accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
1227
+ row_parts.append(f"{accuracy_pct:9.2f}%")
1228
+ LOGGER.info(" ".join(row_parts))
1229
+
1230
+ # Performance metrics comparison
1231
+ LOGGER.info("")
1232
+ LOGGER.info("Model Size and Inference Speed Comparison:")
1233
+ LOGGER.info("")
1234
+ header = f"{'Model':<35} {'Size (MB)':>12} {'Speed (ms/sample)':>18}"
1235
+ LOGGER.info(header)
1236
+ LOGGER.info("-" * 80)
1237
+
1238
+ for model_name, result in results.items():
1239
+ LOGGER.info(
1240
+ "%-35s %11.2f %17.4f",
1241
+ model_name[:35],
1242
+ result["model_size_mb"],
1243
+ result["inference_time_ms"],
1244
+ )
1245
+
1246
+ # Find best model
1247
+ LOGGER.info("")
1248
+ LOGGER.info("=" * 80)
1249
+ LOGGER.info("BEST MODELS BY METRIC")
1250
+ LOGGER.info("=" * 80)
1251
+ LOGGER.info("")
1252
+
1253
+ metrics = ["hue_mae", "value_mae", "chroma_mae", "code_mae"]
1254
+ metric_names = ["Hue MAE", "Value MAE", "Chroma MAE", "Code MAE"]
1255
+
1256
+ for metric, metric_name in zip(metrics, metric_names, strict=False):
1257
+ best_model = min(results.items(), key=lambda x: x[1][metric])
1258
+ LOGGER.info(
1259
+ "%-15s: %s (%.4f)",
1260
+ metric_name,
1261
+ best_model[0],
1262
+ best_model[1][metric],
1263
+ )
1264
+
1265
+ # Overall best (average rank)
1266
+ LOGGER.info("")
1267
+ LOGGER.info("Overall Best (by average component MAE):")
1268
+ for model_name, result in results.items():
1269
+ avg_mae = np.mean(
1270
+ [
1271
+ result["hue_mae"],
1272
+ result["value_mae"],
1273
+ result["chroma_mae"],
1274
+ result["code_mae"],
1275
+ ]
1276
+ )
1277
+ LOGGER.info(" %s: %.4f", model_name, avg_mae)
1278
+
1279
+ LOGGER.info("")
1280
+ LOGGER.info("=" * 80)
1281
+
1282
+ # Generate HTML report
1283
+ report_dir = PROJECT_ROOT / "reports" / "from_xyY"
1284
+ report_dir.mkdir(exist_ok=True)
1285
+ report_file = report_dir / "model_comparison.html"
1286
+ generate_html_report(
1287
+ results, len(xyY_samples), report_file, baseline_inference_time_ms
1288
+ )
1289
+
1290
+
1291
# Entry point: run the full model comparison when executed as a script.
if __name__ == "__main__":
    main()
learning_munsell/comparison/from_xyY/compare_gamma_model.py ADDED
@@ -0,0 +1,390 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Quick comparison of the gamma-corrected models against baselines.
3
+
4
+ This script compares:
5
+ 1. MLP (Base) vs MLP (Gamma 2.33)
6
+ 2. Multi-Head (Base) vs Multi-Head (Gamma 2.33) vs Multi-Head (ST.2084)
7
+ """
8
+
9
+ import logging
10
+ from typing import Any
11
+
12
+ import numpy as np
13
+ import onnxruntime as ort
14
+ from colour import XYZ_to_Lab, xyY_to_XYZ
15
+ from colour.difference import delta_E_CIE2000
16
+ from colour.models import eotf_inverse_ST2084
17
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
18
+ from colour.notation.munsell import (
19
+ CCS_ILLUMINANT_MUNSELL,
20
+ munsell_colour_to_munsell_specification,
21
+ munsell_specification_to_xyY,
22
+ )
23
+ from numpy.typing import NDArray
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+
27
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
28
+ LOGGER = logging.getLogger(__name__)
29
+
30
+
31
def normalize_input_standard(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Min-max normalize the x, y and Y channels using the stored ranges."""
    result = np.copy(X)
    for channel, key in enumerate(("x_range", "y_range", "Y_range")):
        low, high = params[key]
        result[..., channel] = (X[..., channel] - low) / (high - low)
    return result.astype(np.float32)
44
+
45
+
46
def normalize_input_gamma(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Min-max normalize xyY, then gamma-encode the luminance channel."""
    exponent = 1.0 / params.get("gamma", 2.33)
    result = np.copy(X)
    for channel, key in enumerate(("x_range", "y_range")):
        low, high = params[key]
        result[..., channel] = (X[..., channel] - low) / (high - low)
    # Normalize Y to [0, 1], clamp, then apply the inverse-gamma power curve.
    Y_low, Y_high = params["Y_range"]
    Y_unit = np.clip((X[..., 2] - Y_low) / (Y_high - Y_low), 0, 1)
    result[..., 2] = np.power(Y_unit, exponent)
    return result.astype(np.float32)
63
+
64
+
65
def normalize_input_st2084(X: NDArray, params: dict[str, Any]) -> NDArray:
    """Min-max normalize xyY, then PQ-encode (SMPTE ST 2084) the luminance."""
    peak_luminance = params.get("L_p", 100.0)
    result = np.copy(X)
    for channel, key in enumerate(("x_range", "y_range")):
        low, high = params[key]
        result[..., channel] = (X[..., channel] - low) / (high - low)
    # Normalize Y to [0, 1] and clamp before encoding.
    Y_low, Y_high = params["Y_range"]
    Y_unit = np.clip((X[..., 2] - Y_low) / (Y_high - Y_low), 0, 1)
    # Express luminance in cd/m² before applying the ST.2084 inverse EOTF.
    result[..., 2] = eotf_inverse_ST2084(Y_unit * peak_luminance, L_p=peak_luminance)
    return result.astype(np.float32)
84
+
85
+
86
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """Map normalized network outputs back to Munsell component ranges."""
    result = np.copy(y_norm)
    # Channel order matches the network output: hue, value, chroma, code.
    component_keys = ("hue_range", "value_range", "chroma_range", "code_range")
    for channel, key in enumerate(component_keys):
        low, high = params[key]
        result[..., channel] = y_norm[..., channel] * (high - low) + low
    return result
106
+
107
+
108
def clamp_munsell_specification(spec: NDArray) -> NDArray:
    """Clip a Munsell specification to the ranges the colour library accepts."""
    clamped = np.copy(spec)
    # (low, high) per channel: hue, value (library constraint), chroma, code.
    bounds = ((0.0, 10.0), (1.0, 9.0), (0.0, 50.0), (1.0, 10.0))
    for channel, (low, high) in enumerate(bounds):
        clamped[..., channel] = np.clip(spec[..., channel], low, high)
    return clamped
116
+
117
+
118
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
    """Compute per-sample Delta-E CIE2000 between predictions and reference Lab."""
    values: list[float] = []
    for prediction, target_Lab in zip(pred, reference_Lab):
        try:
            specification = clamp_munsell_specification(prediction)
            rounded = specification.copy()
            # The round-trip conversion expects an integral hue-family code.
            rounded[3] = round(specification[3])
            xyY = munsell_specification_to_xyY(rounded)
            Lab = XYZ_to_Lab(xyY_to_XYZ(xyY), CCS_ILLUMINANT_MUNSELL)
            values.append(delta_E_CIE2000(target_Lab, Lab))
        except (RuntimeError, ValueError):
            # Skip samples whose specification cannot be converted.
            continue
    return values
134
+
135
+
136
def _load_real_munsell_dataset() -> tuple[NDArray, NDArray, NDArray]:
    """Load real Munsell colours as (xyY, Munsell specification, Lab) arrays.

    Colours that the colour library cannot convert are silently skipped.
    """
    xyY_values = []
    munsell_specs = []
    reference_Lab = []

    for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            hue_code, value, chroma = munsell_spec_tuple
            munsell_str = f"{hue_code} {value}/{chroma}"
            spec = munsell_colour_to_munsell_specification(munsell_str)
            # Dataset stores Y on a 0-100 scale; rescale to [0, 1].
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])

            XYZ = xyY_to_XYZ(xyY_scaled)
            Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)

            xyY_values.append(xyY_scaled)
            munsell_specs.append(spec)
            reference_Lab.append(Lab)
        except (RuntimeError, ValueError):
            continue

    return np.array(xyY_values), np.array(munsell_specs), np.array(reference_Lab)


def _evaluate_onnx_model(
    models_dir,
    stem: str,
    input_name: str,
    normalize_fn,
    xyY_array: NDArray,
    ground_truth: NDArray,
    reference_Lab: NDArray,
) -> list[float]:
    """Evaluate one ONNX model, log its metrics and return its Delta-E values.

    Parameters
    ----------
    models_dir
        Directory containing ``{stem}.onnx`` and
        ``{stem}_normalization_params.npz``.
    stem
        Base file name of the model.
    input_name
        Name of the ONNX graph input.
    normalize_fn
        Callable ``(X, input_params) -> NDArray`` normalizing the xyY input.
    xyY_array, ground_truth, reference_Lab
        Evaluation samples, target Munsell specifications and reference Lab.

    Returns
    -------
    list[float]
        Per-sample Delta-E values, or an empty list when the files are missing.
    """
    onnx_file = models_dir / f"{stem}.onnx"
    params_file = models_dir / f"{stem}_normalization_params.npz"

    if not (onnx_file.exists() and params_file.exists()):
        LOGGER.info(" Model not found, skipping...")
        return []

    session = ort.InferenceSession(str(onnx_file))
    params_data = np.load(params_file, allow_pickle=True)
    input_params = params_data["input_params"].item()
    output_params = params_data["output_params"].item()

    X_norm = normalize_fn(xyY_array, input_params)
    pred_norm = session.run(None, {input_name: X_norm})[0]
    pred = denormalize_output(pred_norm, output_params)

    errors = np.abs(pred - ground_truth)
    delta_E = compute_delta_e(pred, reference_Lab)

    LOGGER.info(" Hue MAE: %.4f", np.mean(errors[:, 0]))
    LOGGER.info(" Value MAE: %.4f", np.mean(errors[:, 1]))
    LOGGER.info(" Chroma MAE: %.4f", np.mean(errors[:, 2]))
    LOGGER.info(" Code MAE: %.4f", np.mean(errors[:, 3]))
    LOGGER.info(
        " Delta-E: %.4f (mean), %.4f (median)",
        np.mean(delta_E),
        np.median(delta_E),
    )
    return delta_E


def _log_section(title: str) -> None:
    """Log a minor (dashed) section header."""
    LOGGER.info("\n" + "-" * 40)
    LOGGER.info(title)
    LOGGER.info("-" * 40)


def _log_improvement(
    label: str, base_delta_E: list[float], candidate_delta_E: list[float]
) -> None:
    """Log the relative mean Delta-E improvement of a candidate over a base."""
    base_mean = np.mean(base_delta_E)
    improvement = (base_mean - np.mean(candidate_delta_E)) / base_mean * 100
    if improvement > 0:
        LOGGER.info(" %s %.1f%% BETTER", label, improvement)
    else:
        LOGGER.info(" %s %.1f%% WORSE", label, -improvement)


def main() -> None:
    """Compare gamma/PQ-encoded models against the standard baselines."""
    LOGGER.info("=" * 80)
    LOGGER.info("Gamma Model Comparison: MLP vs MLP (Gamma 2.33)")
    LOGGER.info("=" * 80)

    models_dir = PROJECT_ROOT / "models" / "from_xyY"

    LOGGER.info("\nLoading real Munsell colours...")
    xyY_array, ground_truth, reference_Lab = _load_real_munsell_dataset()
    LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))

    samples = (xyY_array, ground_truth, reference_Lab)

    # --- MLP experiment -----------------------------------------------------
    _log_section("1. MLP (Base) - Standard Normalization")
    delta_E_base = _evaluate_onnx_model(
        models_dir, "mlp", "xyY", normalize_input_standard, *samples
    )

    _log_section("2. MLP (Gamma 2.33) - Gamma-Corrected Y")
    delta_E_gamma = _evaluate_onnx_model(
        models_dir, "mlp_gamma", "xyY_gamma", normalize_input_gamma, *samples
    )

    if delta_E_base and delta_E_gamma:
        LOGGER.info("\n" + "=" * 80)
        LOGGER.info("MLP COMPARISON SUMMARY")
        LOGGER.info("=" * 80)
        LOGGER.info("")
        LOGGER.info("Delta-E (lower is better):")
        LOGGER.info(
            " MLP (Base): %.4f mean, %.4f median",
            np.mean(delta_E_base),
            np.median(delta_E_base),
        )
        LOGGER.info(
            " MLP (Gamma): %.4f mean, %.4f median",
            np.mean(delta_E_gamma),
            np.median(delta_E_gamma),
        )
        LOGGER.info("")
        _log_improvement("Gamma model is", delta_E_base, delta_E_gamma)

    # --- Multi-Head experiment ----------------------------------------------
    LOGGER.info("\n" + "=" * 80)
    LOGGER.info("MULTI-HEAD GAMMA EXPERIMENT")
    LOGGER.info("=" * 80)

    _log_section("3. Multi-Head (Base) - Standard Normalization")
    delta_E_mh_base = _evaluate_onnx_model(
        models_dir, "multi_head", "xyY", normalize_input_standard, *samples
    )

    _log_section("4. Multi-Head (Gamma 2.33) - Gamma-Corrected Y")
    delta_E_mh_gamma = _evaluate_onnx_model(
        models_dir, "multi_head_gamma", "xyY_gamma", normalize_input_gamma, *samples
    )

    _log_section("5. Multi-Head (ST.2084) - PQ-Encoded Y")
    delta_E_mh_st2084 = _evaluate_onnx_model(
        models_dir,
        "multi_head_st2084",
        "xyY_st2084",
        normalize_input_st2084,
        *samples,
    )

    if delta_E_mh_base and delta_E_mh_gamma:
        LOGGER.info("\n" + "=" * 80)
        LOGGER.info("MULTI-HEAD COMPARISON SUMMARY")
        LOGGER.info("=" * 80)
        LOGGER.info("")
        LOGGER.info("Delta-E (lower is better):")
        LOGGER.info(
            " Multi-Head (Base): %.4f mean, %.4f median",
            np.mean(delta_E_mh_base),
            np.median(delta_E_mh_base),
        )
        LOGGER.info(
            " Multi-Head (Gamma): %.4f mean, %.4f median",
            np.mean(delta_E_mh_gamma),
            np.median(delta_E_mh_gamma),
        )
        if delta_E_mh_st2084:
            LOGGER.info(
                " Multi-Head (ST.2084): %.4f mean, %.4f median",
                np.mean(delta_E_mh_st2084),
                np.median(delta_E_mh_st2084),
            )
        LOGGER.info("")

        _log_improvement(
            "Multi-Head Gamma vs Base:", delta_E_mh_base, delta_E_mh_gamma
        )
        if delta_E_mh_st2084:
            _log_improvement(
                "Multi-Head ST.2084 vs Base:", delta_E_mh_base, delta_E_mh_st2084
            )
            _log_improvement(
                "Multi-Head ST.2084 vs Gamma:", delta_E_mh_gamma, delta_E_mh_st2084
            )

    LOGGER.info("\n" + "=" * 80)
387
+
388
+
389
# Entry point: run the gamma-model comparison when executed as a script.
if __name__ == "__main__":
    main()
learning_munsell/comparison/to_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Comparison scripts for Munsell to xyY conversion models."""
learning_munsell/comparison/to_xyY/compare_all_models.py ADDED
@@ -0,0 +1,617 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare all ML models for Munsell to xyY conversion on real Munsell data.
3
+
4
+ Models to compare:
5
+ 1. Simple MLP Approximator
6
+ 2. Multi-Head MLP
7
+ 3. Multi-Head MLP (Optimized) - with hyperparameter optimization
8
+ 4. Multi-Head + Multi-Error Predictor
9
+ 5. Multi-MLP - 3 independent branches
10
+ 6. Multi-MLP (Optimized) - 3 independent branches with optimized hyperparameters
11
+ 7. Multi-MLP + Error Predictor
12
+ 8. Multi-MLP + Multi-Error Predictor
13
+ 9. Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ import time
20
+ import warnings
21
+ from typing import TYPE_CHECKING
22
+
23
+ import numpy as np
24
+ import onnxruntime as ort
25
+ from colour import XYZ_to_Lab, xyY_to_XYZ
26
+ from colour.difference import delta_E_CIE2000
27
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
28
+ from colour.notation.munsell import (
29
+ CCS_ILLUMINANT_MUNSELL,
30
+ munsell_colour_to_munsell_specification,
31
+ munsell_specification_to_xyY,
32
+ )
33
+ from numpy.typing import NDArray # noqa: TC002
34
+
35
+ from learning_munsell import PROJECT_ROOT
36
+ from learning_munsell.utilities.common import (
37
+ benchmark_inference_speed,
38
+ generate_html_report_footer,
39
+ generate_html_report_header,
40
+ generate_ranking_section,
41
+ get_model_size_mb,
42
+ )
43
+
44
+ if TYPE_CHECKING:
45
+ from pathlib import Path
46
+
47
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
48
+ LOGGER = logging.getLogger(__name__)
49
+
50
+
51
def normalize_munsell(munsell: np.ndarray) -> np.ndarray:
    """Scale Munsell components into [0, 1] for network input."""
    # Per-component divisors: hue (in decade), value, chroma, code.
    divisors = np.array([10.0, 10.0, 50.0, 10.0])
    return (munsell / divisors).astype(np.float32)
59
+
60
+
61
def evaluate_model(
    session: ort.InferenceSession,
    X_norm: np.ndarray,
    ground_truth: np.ndarray,
    input_name: str = "munsell_normalized",
) -> dict:
    """Run one model over the normalized inputs and summarize its errors."""
    predictions = session.run(None, {input_name: X_norm})[0]
    abs_errors = np.abs(predictions - ground_truth)

    summary = {
        "x_mae": np.mean(abs_errors[:, 0]),
        "y_mae": np.mean(abs_errors[:, 1]),
        "Y_mae": np.mean(abs_errors[:, 2]),
        "predictions": predictions,
        "errors": abs_errors,
        "max_errors": np.max(abs_errors, axis=1),
    }
    return summary
79
+
80
+
81
def compute_delta_E(
    ml_predictions: np.ndarray,
    reference_xyY: np.ndarray,
) -> float:
    """Mean Delta-E CIE2000 between predicted and reference (ground truth) xyY."""

    def _to_Lab(xyY: np.ndarray) -> np.ndarray:
        # Lab is computed relative to the Munsell reference illuminant.
        return XYZ_to_Lab(xyY_to_XYZ(xyY), CCS_ILLUMINANT_MUNSELL)

    collected = []
    for predicted_xyY, target_xyY in zip(ml_predictions, reference_xyY, strict=False):
        try:
            value = delta_E_CIE2000(_to_Lab(target_xyY), _to_Lab(predicted_xyY))
        except (RuntimeError, ValueError):
            # Skip samples whose conversion fails.
            continue
        if not np.isnan(value):
            collected.append(value)

    return np.mean(collected) if collected else np.nan
103
+
104
+
105
+ def generate_html_report(
106
+ results: dict,
107
+ num_samples: int,
108
+ output_file: Path,
109
+ baseline_inference_time_ms: float,
110
+ ) -> None:
111
+ """Generate HTML report with visualizations."""
112
+ # Calculate average MAE
113
+ avg_maes = {}
114
+ for model_name, result in results.items():
115
+ avg_maes[model_name] = np.mean(
116
+ [
117
+ result["x_mae"],
118
+ result["y_mae"],
119
+ result["Y_mae"],
120
+ ]
121
+ )
122
+
123
+ # Sort by average MAE
124
+ sorted_models = sorted(avg_maes.items(), key=lambda x: x[1])
125
+
126
+ # Start HTML
127
+ html = generate_html_report_header(
128
+ title="ML Model Comparison Report",
129
+ subtitle="Munsell to xyY Conversion",
130
+ num_samples=num_samples,
131
+ )
132
+
133
+ # Best Models Summary
134
+ best_size = min(results.items(), key=lambda x: x[1]["model_size_mb"])[0]
135
+ best_speed = min(results.items(), key=lambda x: x[1]["inference_time_ms"])[0]
136
+ best_avg = sorted_models[0][0]
137
+
138
+ # Find best Delta-E
139
+ delta_E_results = [
140
+ (n, r["delta_E"]) for n, r in results.items() if not np.isnan(r["delta_E"])
141
+ ]
142
+ best_delta_E = (
143
+ min(delta_E_results, key=lambda x: x[1])[0] if delta_E_results else None
144
+ )
145
+
146
+ html += f"""
147
+ <!-- Best Models Summary -->
148
+ <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
149
+ <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Best Models by Metric</h2>
150
+ <div class="grid grid-cols-1 md:grid-cols-2 lg:grid-cols-4 gap-4">
151
+ <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
152
+ <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Smallest Size</div>
153
+ <div class="text-3xl font-bold text-primary mb-3">{results[best_size]["model_size_mb"]:.2f} MB</div>
154
+ <div class="text-sm text-foreground/80">{best_size}</div>
155
+ </div>
156
+ <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
157
+ <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Fastest Speed</div>
158
+ <div class="text-3xl font-bold text-primary mb-3">{results[best_speed]["inference_time_ms"]:.4f} ms</div>
159
+ <div class="text-sm text-foreground/80">{best_speed}</div>
160
+ </div>
161
+ """
162
+
163
+ if best_delta_E:
164
+ html += f"""
165
+ <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
166
+ <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Delta-E</div>
167
+ <div class="text-3xl font-bold text-primary mb-3">{results[best_delta_E]["delta_E"]:.6f}</div>
168
+ <div class="text-sm text-foreground/80">{best_delta_E}</div>
169
+ </div>
170
+ """
171
+
172
+ html += f"""
173
+ <div class="bg-gradient-to-br from-primary/10 to-primary/5 rounded-lg p-5 border border-primary/20">
174
+ <div class="text-xs font-semibold text-muted-foreground uppercase tracking-wide mb-2">Best Average MAE</div>
175
+ <div class="text-3xl font-bold text-primary mb-3">{avg_maes[best_avg]:.6f}</div>
176
+ <div class="text-sm text-foreground/80">{best_avg}</div>
177
+ </div>
178
+ </div>
179
+ </div>
180
+ """
181
+
182
+ # Performance Metrics Table
183
+ sorted_by_avg_mae = sorted(results.items(), key=lambda x: avg_maes[x[0]])
184
+
185
+ html += """
186
+ <!-- Performance Metrics Table -->
187
+ <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
188
+ <h2 class="text-2xl font-semibold mb-6 pb-3 border-b border-primary/30">Model Performance Metrics</h2>
189
+ <div class="overflow-x-auto">
190
+ <table class="w-full text-sm">
191
+ <thead>
192
+ <tr class="border-b border-border">
193
+ <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
194
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Size (MB)</th>
195
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Speed (ms/sample)</th>
196
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">vs Baseline</th>
197
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE x</th>
198
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE y</th>
199
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">MAE Y</th>
200
+ <th class="text-right py-3 px-4 font-semibold text-muted-foreground">Delta-E</th>
201
+ </tr>
202
+ </thead>
203
+ <tbody>
204
+ """
205
+
206
+ for model_name, result in sorted_by_avg_mae:
207
+ size_mb = result["model_size_mb"]
208
+ speed_ms = result["inference_time_ms"]
209
+ delta_E = result["delta_E"]
210
+
211
+ # Calculate speedup vs baseline
212
+ speedup = baseline_inference_time_ms / speed_ms if speed_ms > 0 else 0
213
+
214
+ size_class = "text-primary font-semibold" if model_name == best_size else ""
215
+ speed_class = "text-primary font-semibold" if model_name == best_speed else ""
216
+ delta_E_class = (
217
+ "text-primary font-semibold" if model_name == best_delta_E else ""
218
+ )
219
+
220
+ delta_E_str = f"{delta_E:.6f}" if not np.isnan(delta_E) else "—"
221
+
222
+ speedup_text = f"{speedup:.0f}x" if speedup > 100 else f"{speedup:.1f}x"
223
+
224
+ html += f"""
225
+ <tr class="border-b border-border/50 hover:bg-muted/30 transition-colors">
226
+ <td class="py-3 px-4 font-medium">{model_name}</td>
227
+ <td class="py-3 px-4 text-right {size_class}">{size_mb:.2f}</td>
228
+ <td class="py-3 px-4 text-right {speed_class}">{speed_ms:.4f}</td>
229
+ <td class="py-3 px-4 text-right text-primary font-semibold">{speedup_text}</td>
230
+ <td class="py-3 px-4 text-right">{result["x_mae"]:.6f}</td>
231
+ <td class="py-3 px-4 text-right">{result["y_mae"]:.6f}</td>
232
+ <td class="py-3 px-4 text-right">{result["Y_mae"]:.6f}</td>
233
+ <td class="py-3 px-4 text-right {delta_E_class}">{delta_E_str}</td>
234
+ </tr>
235
+ """
236
+
237
+ html += """
238
+ </tbody>
239
+ </table>
240
+ </div>
241
+ </div>
242
+ """
243
+
244
+ # Add ranking section
245
+ html += generate_ranking_section(
246
+ results,
247
+ metric_key="avg_mae",
248
+ title="Overall Ranking (by Average MAE)",
249
+ )
250
+
251
+ # Precision thresholds
252
+ thresholds = [1e-5, 1e-4, 1e-3, 1e-2, 1e-1]
253
+
254
+ html += """
255
+ <div class="bg-card rounded-lg border border-border p-6 shadow-lg">
256
+ <h2 class="text-2xl font-semibold mb-3 pb-3 border-b border-primary/30">Accuracy at Precision Thresholds</h2>
257
+ <p class="text-sm text-muted-foreground mb-6">Percentage of predictions where max error across all components is below threshold:</p>
258
+ <div class="overflow-x-auto">
259
+ <table class="w-full text-sm">
260
+ <thead>
261
+ <tr class="border-b border-border">
262
+ <th class="text-left py-3 px-4 font-semibold text-muted-foreground">Model</th>
263
+ """
264
+
265
+ for threshold in thresholds:
266
+ html += f' <th class="text-right py-3 px-4 font-semibold text-muted-foreground">&lt; {threshold:.0e}</th>\n'
267
+
268
+ html += """
269
+ </tr>
270
+ </thead>
271
+ <tbody>
272
+ """
273
+
274
+ for model_name, _ in sorted_models:
275
+ result = results[model_name]
276
+ html += f"""
277
+ <tr class="border-b border-border hover:bg-muted/30 transition-colors">
278
+ <td class="text-left py-3 px-4 font-medium">{model_name}</td>
279
+ """
280
+ for threshold in thresholds:
281
+ accuracy_pct = np.mean(result["max_errors"] < threshold) * 100
282
+ html += f' <td class="text-right py-3 px-4">{accuracy_pct:.2f}%</td>\n'
283
+
284
+ html += """
285
+ </tr>
286
+ """
287
+
288
+ html += """
289
+ </tbody>
290
+ </table>
291
+ </div>
292
+ </div>
293
+ """
294
+
295
+ html += generate_html_report_footer()
296
+
297
+ # Write HTML file
298
+ with open(output_file, "w") as f:
299
+ f.write(html)
300
+
301
+ LOGGER.info("")
302
+ LOGGER.info("HTML report saved to: %s", output_file)
303
+
304
+
305
def main() -> None:
    """
    Compare all Munsell-to-xyY conversion models.

    Workflow:

    1. Load the real Munsell colours (``MUNSELL_COLOURS_REAL``) and their
       measured xyY values, rescaling Y by 1/100 to match model outputs.
    2. Benchmark the colour-science ``munsell_specification_to_xyY``
       reference implementation to obtain a speed baseline.
    3. Evaluate every ONNX model (single-stage networks and two-stage
       base + error-predictor pipelines) for per-component MAE, Delta-E,
       file size and inference speed.
    4. Log a summary table and write an HTML comparison report.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Munsell to xyY Model Comparison")
    LOGGER.info("=" * 80)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "to_xyY"

    # Load real Munsell dataset
    LOGGER.info("")
    LOGGER.info("Loading real Munsell dataset...")
    munsell_specs = []
    xyY_ground_truth = []

    # Colours whose notation fails to parse are skipped silently; only
    # valid entries make it into the evaluation set.
    for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            hue_code, value, chroma = munsell_spec_tuple
            munsell_str = f"{hue_code} {value}/{chroma}"
            spec = munsell_colour_to_munsell_specification(munsell_str)
            # Y is rescaled from ~[0, 100] into [0, 1] to match model outputs.
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
            munsell_specs.append(spec)
            xyY_ground_truth.append(xyY_scaled)
        except Exception:  # noqa: BLE001, S112
            continue

    munsell_specs = np.array(munsell_specs, dtype=np.float32)
    xyY_ground_truth = np.array(xyY_ground_truth, dtype=np.float32)
    LOGGER.info("Loaded %d valid Munsell colors", len(munsell_specs))

    # Normalize inputs
    munsell_normalized = normalize_munsell(munsell_specs)

    # Benchmark colour library first
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Colour Library (munsell_specification_to_xyY)")
    LOGGER.info("=" * 80)

    # Benchmark the munsell_specification_to_xyY function
    # Note: Using full dataset (100% of samples)

    # Set random seed for reproducibility
    np.random.seed(42)

    # Use 100% of samples for comprehensive benchmarking
    sampled_indices = np.arange(len(munsell_specs))
    munsell_benchmark = munsell_specs[sampled_indices]

    start_time = time.perf_counter()
    colour_predictions = []
    successful_inferences = 0

    # NaN placeholders keep prediction rows aligned with the input set
    # when the reference conversion fails for a colour.
    with warnings.catch_warnings():
        warnings.simplefilter("ignore")
        for spec in munsell_benchmark:
            try:
                xyY = munsell_specification_to_xyY(spec)
                colour_predictions.append(xyY)
                successful_inferences += 1
            except (RuntimeError, ValueError):
                colour_predictions.append(np.array([np.nan, np.nan, np.nan]))

    end_time = time.perf_counter()

    total_time_s = end_time - start_time
    # Per-sample baseline latency; 0 if every conversion failed.
    baseline_inference_time_ms = (
        (total_time_s / successful_inferences) * 1000
        if successful_inferences > 0
        else 0
    )
    colour_predictions = np.array(colour_predictions)

    LOGGER.info(" Successful inferences: %d", successful_inferences)
    LOGGER.info(" Inference Speed: %.4f ms/sample", baseline_inference_time_ms)

    # Define models to compare
    models = [
        {
            "name": "Simple MLP",
            "files": [model_directory / "munsell_to_xyY_approximator.onnx"],
            "params_file": model_directory
            / "munsell_to_xyY_approximator_normalization_params.npz",
            "type": "single",
        },
        {
            "name": "Multi-Head",
            "files": [model_directory / "multi_head.onnx"],
            "params_file": model_directory / "multi_head_normalization_params.npz",
            "type": "single",
        },
        {
            "name": "Multi-Head (Optimized)",
            "files": [model_directory / "multi_head_optimized.onnx"],
            "params_file": model_directory
            / "multi_head_optimized_normalization_params.npz",
            "type": "single",
        },
        {
            "name": "Multi-Head + Multi-Error Predictor",
            "files": [
                model_directory / "multi_head.onnx",
                model_directory / "multi_head_multi_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_head_multi_error_predictor_normalization_params.npz",
            "type": "two_stage",
        },
        {
            "name": "Multi-MLP",
            "files": [model_directory / "multi_mlp.onnx"],
            "params_file": model_directory / "multi_mlp_normalization_params.npz",
            "type": "single",
        },
        {
            "name": "Multi-MLP (Optimized)",
            "files": [model_directory / "multi_mlp_optimized.onnx"],
            "params_file": model_directory
            / "multi_mlp_optimized_normalization_params.npz",
            "type": "single",
        },
        {
            "name": "Multi-MLP + Error Predictor",
            "files": [
                model_directory / "multi_mlp.onnx",
                model_directory / "multi_mlp_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_mlp_error_predictor_normalization_params.npz",
            "type": "two_stage",
        },
        {
            "name": "Multi-MLP + Multi-Error Predictor",
            "files": [
                model_directory / "multi_mlp.onnx",
                model_directory / "multi_mlp_multi_error_predictor.onnx",
            ],
            "params_file": model_directory
            / "multi_mlp_multi_error_predictor_normalization_params.npz",
            "type": "two_stage",
        },
        {
            "name": "Multi-MLP (Optimized) + Multi-Error Predictor (Optimized)",
            "files": [
                model_directory / "multi_mlp_optimized.onnx",
                model_directory / "multi_mlp_multi_error_predictor_optimized.onnx",
            ],
            "params_file": model_directory
            / "multi_mlp_multi_error_predictor_optimized_normalization_params.npz",
            "type": "two_stage",
        },
    ]

    # Evaluate each model
    results = {}

    for model_info in models:
        model_name = model_info["name"]
        LOGGER.info("")
        LOGGER.info("=" * 80)
        LOGGER.info(model_name)
        LOGGER.info("=" * 80)

        # Calculate model size
        model_size_mb = get_model_size_mb(model_info["files"])

        if model_info["type"] == "two_stage":
            # Two-stage model
            base_session = ort.InferenceSession(str(model_info["files"][0]))
            error_session = ort.InferenceSession(str(model_info["files"][1]))
            error_input_name = error_session.get_inputs()[0].name

            # Define inference callable.  Default arguments bind the
            # *current* loop iteration's sessions, avoiding Python's
            # late-binding-closure pitfall inside the for-loop.
            def two_stage_inference(
                _base_session: ort.InferenceSession = base_session,
                _error_session: ort.InferenceSession = error_session,
                _munsell_normalized: NDArray = munsell_normalized,
                _error_input_name: str = error_input_name,
            ) -> NDArray:
                base_pred = _base_session.run(
                    None, {"munsell_normalized": _munsell_normalized}
                )[0]
                # Error predictor consumes the input concatenated with the
                # base prediction and emits an additive correction.
                combined = np.concatenate(
                    [_munsell_normalized, base_pred], axis=1
                ).astype(np.float32)
                error_corr = _error_session.run(
                    None, {_error_input_name: combined}
                )[0]
                return base_pred + error_corr

            # Benchmark speed
            inference_time_ms = benchmark_inference_speed(
                two_stage_inference, munsell_normalized
            )

            # Get predictions
            base_pred = base_session.run(
                None, {"munsell_normalized": munsell_normalized}
            )[0]
            combined = np.concatenate(
                [munsell_normalized, base_pred], axis=1
            ).astype(np.float32)
            error_corr = error_session.run(
                None, {error_input_name: combined}
            )[0]
            pred = base_pred + error_corr

            errors = np.abs(pred - xyY_ground_truth)
            result = {
                "x_mae": np.mean(errors[:, 0]),
                "y_mae": np.mean(errors[:, 1]),
                "Y_mae": np.mean(errors[:, 2]),
                "predictions": pred,
                "errors": errors,
                "max_errors": np.max(errors, axis=1),
            }
        else:
            # Single model
            session = ort.InferenceSession(str(model_info["files"][0]))

            # Define inference callable (defaults bind this iteration's
            # session; see note above).
            def single_inference(
                _session: ort.InferenceSession = session,
                _munsell_normalized: NDArray = munsell_normalized,
            ) -> NDArray:
                return _session.run(
                    None, {"munsell_normalized": _munsell_normalized}
                )[0]

            # Benchmark speed
            inference_time_ms = benchmark_inference_speed(
                single_inference, munsell_normalized
            )

            result = evaluate_model(session, munsell_normalized, xyY_ground_truth)

        result["model_size_mb"] = model_size_mb
        result["inference_time_ms"] = inference_time_ms
        result["avg_mae"] = np.mean(
            [result["x_mae"], result["y_mae"], result["Y_mae"]]
        )

        # Compute Delta-E against ground truth (measured xyY)
        sampled_predictions = result["predictions"][sampled_indices]
        result["delta_E"] = compute_delta_E(
            sampled_predictions,
            xyY_ground_truth,
        )

        results[model_name] = result

        # Print results
        LOGGER.info("")
        LOGGER.info("Mean Absolute Errors:")
        LOGGER.info(" x: %.6f", result["x_mae"])
        LOGGER.info(" y: %.6f", result["y_mae"])
        LOGGER.info(" Y: %.6f", result["Y_mae"])
        if not np.isnan(result["delta_E"]):
            LOGGER.info(" Delta-E (vs Ground Truth): %.6f", result["delta_E"])
        LOGGER.info("")
        LOGGER.info("Performance Metrics:")
        LOGGER.info(" Model Size: %.2f MB", result["model_size_mb"])
        LOGGER.info(
            " Inference Speed: %.4f ms/sample", result["inference_time_ms"]
        )
        LOGGER.info(
            " Speedup vs Colour: %.1fx",
            baseline_inference_time_ms / inference_time_ms,
        )

    # Summary
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("SUMMARY COMPARISON")
    LOGGER.info("=" * 80)
    LOGGER.info("")

    if not results:
        LOGGER.info("No models were successfully evaluated")
        return

    # MAE comparison table
    LOGGER.info("Mean Absolute Error Comparison:")
    LOGGER.info("")
    header = f"{'Model':<40} {'x':>10} {'y':>10} {'Y':>10} {'Delta-E':>12}"
    LOGGER.info(header)
    LOGGER.info("-" * 85)

    for model_name, result in results.items():
        delta_E_str = (
            f"{result['delta_E']:.6f}" if not np.isnan(result["delta_E"]) else "N/A"
        )
        LOGGER.info(
            "%-40s %10.6f %10.6f %10.6f %12s",
            model_name,
            result["x_mae"],
            result["y_mae"],
            result["Y_mae"],
            delta_E_str,
        )

    # Generate HTML report
    report_dir = PROJECT_ROOT / "reports" / "to_xyY"
    report_dir.mkdir(parents=True, exist_ok=True)
    report_file = report_dir / "model_comparison.html"
    generate_html_report(
        results, len(munsell_specs), report_file, baseline_inference_time_ms
    )
614
+
615
+
616
# Script entry point: run the full model comparison when executed directly.
if __name__ == "__main__":
    main()
learning_munsell/data_generation/generate_training_data.py ADDED
@@ -0,0 +1,310 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Generate training data for ML-based xyY to Munsell conversion.
3
+
4
+ Generates samples by sampling in Munsell space and converting to xyY via
5
+ forward conversion, guaranteeing 100% valid samples.
6
+
7
+ Usage:
8
+ uv run python -m learning_munsell.data_generation.generate_training_data
9
+ uv run python -m learning_munsell.data_generation.generate_training_data \\
10
+ --n-samples 2000000 --perturbation 0.10 --output training_data_large
11
+ """
12
+
13
+ import argparse
14
+ import json
15
+ import logging
16
+ import multiprocessing as mp
17
+ import warnings
18
+ from datetime import datetime, timezone
19
+
20
+ import numpy as np
21
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
22
+ from colour.notation.munsell import (
23
+ munsell_colour_to_munsell_specification,
24
+ munsell_specification_to_xyY,
25
+ )
26
+ from colour.utilities import ColourUsageWarning
27
+ from numpy.typing import NDArray
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT
31
+
32
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
def _worker_generate_samples(
    args: tuple[int, NDArray, int, float],
) -> tuple[list[NDArray], list[NDArray]]:
    """
    Generate perturbed Munsell/xyY sample pairs in one worker process.

    Parameters
    ----------
    args : tuple
        ``(worker_id, base_specs, samples_per_base, perturbation_pct)``:
        worker identifier (seeds the RNG), base Munsell specifications to
        perturb, number of perturbed draws per base colour, and the
        perturbation amplitude as a fraction of each component's range.

    Returns
    -------
    tuple
        Lists of xyY arrays and matching Munsell specification arrays for
        every perturbation that converted successfully.
    """
    worker_id, base_specs, samples_per_base, perturbation_pct = args

    # Deterministic but distinct seed per worker: reproducible runs with
    # decorrelated streams.
    np.random.seed(42 + worker_id)

    # The forward conversion is noisy near gamut boundaries; silence it.
    warnings.filterwarnings("ignore", category=ColourUsageWarning)
    warnings.filterwarnings("ignore", category=RuntimeWarning)

    collected_xyY: list[NDArray] = []
    collected_specs: list[NDArray] = []

    # (component index, range span, clip low, clip high) for the hue,
    # value and chroma components of the specification vector.
    component_limits = (
        (0, 9.5, 0.5, 10.0),
        (1, 9.0, 1.0, 10.0),
        (2, 50.0, 0.0, 50.0),
    )

    for base_spec in base_specs:
        for _ in range(samples_per_base):
            candidate = base_spec.copy()
            # Draw hue, value, chroma deltas in that order (preserves the
            # RNG call sequence) and clip each back into its valid range.
            for index, span, lower, upper in component_limits:
                delta = np.random.uniform(
                    -perturbation_pct * span, perturbation_pct * span
                )
                candidate[index] = np.clip(base_spec[index] + delta, lower, upper)

            try:
                xyY = munsell_specification_to_xyY(candidate)
            except Exception:  # noqa: BLE001, S112
                continue
            collected_xyY.append(xyY)
            collected_specs.append(candidate)

    return collected_xyY, collected_specs
95
+
96
+
97
def generate_forward_munsell_samples(
    n_samples: int = 500000,
    perturbation_pct: float = 0.05,
    n_workers: int | None = None,
) -> tuple[NDArray, NDArray]:
    """
    Generate samples by sampling directly in Munsell space and converting to xyY.

    Base colours from ``MUNSELL_COLOURS_ALL`` are perturbed and converted
    forward with ``munsell_specification_to_xyY`` in parallel worker
    processes, so every returned pair is valid by construction.

    Parameters
    ----------
    n_samples : int
        Target number of samples to generate.
    perturbation_pct : float
        Perturbation as percentage of valid range.
    n_workers : int, optional
        Number of parallel workers. Defaults to CPU count.

    Returns
    -------
    tuple
        - xyY_samples: Array of shape (n, 3) with xyY values
        - munsell_samples: Array of shape (n, 4) with Munsell specifications
    """
    if n_workers is None:
        n_workers = mp.cpu_count()

    LOGGER.info(
        "Generating %d samples with %.0f%% perturbations using %d workers...",
        n_samples,
        perturbation_pct * 100,
        n_workers,
    )

    # Extract base Munsell specifications
    base_specs = []
    for munsell_spec_tuple, _ in MUNSELL_COLOURS_ALL:
        hue_code_str, value, chroma = munsell_spec_tuple
        munsell_str = f"{hue_code_str} {value}/{chroma}"
        spec = munsell_colour_to_munsell_specification(munsell_str)
        base_specs.append(spec)

    base_specs = np.array(base_specs)
    # +1 over-provisions so that failed conversions still leave enough
    # samples to trim down to n_samples.
    samples_per_base = n_samples // len(base_specs) + 1

    LOGGER.info("Using %d base Munsell colors", len(base_specs))
    LOGGER.info("Generating ~%d samples per base color", samples_per_base)

    # Split base specs across workers.
    # NOTE(review): if n_workers exceeds the number of base colours,
    # specs_per_worker is 0 and the last worker receives everything —
    # confirm this load imbalance is acceptable.
    specs_per_worker = len(base_specs) // n_workers
    worker_args = []

    for i in range(n_workers):
        start_idx = i * specs_per_worker
        # The last worker picks up the remainder of the integer division.
        end_idx = start_idx + specs_per_worker if i < n_workers - 1 else len(base_specs)
        worker_specs = base_specs[start_idx:end_idx]
        worker_args.append((i, worker_specs, samples_per_base, perturbation_pct))

    # Run in parallel
    LOGGER.info("Starting %d parallel workers...", n_workers)
    with mp.Pool(n_workers) as pool:
        results = pool.map(_worker_generate_samples, worker_args)

    # Combine results
    all_xyY = []
    all_munsell = []
    for xyY_samples, munsell_samples in results:
        all_xyY.extend(xyY_samples)
        all_munsell.extend(munsell_samples)

    # Trim to exact sample count
    all_xyY = all_xyY[:n_samples]
    all_munsell = all_munsell[:n_samples]

    LOGGER.info("Generated %d valid samples", len(all_xyY))
    return np.array(all_xyY), np.array(all_munsell)
172
+
173
+
174
def main(
    n_samples: int = 500000,
    perturbation_pct: float = 0.05,
    output: str = "training_data",
) -> None:
    """
    Generate and save training data.

    Parameters
    ----------
    n_samples : int
        Target number of (xyY, Munsell) pairs to generate.
    perturbation_pct : float
        Perturbation amplitude as a fraction of each component's range.
    output : str
        Output filename stem; data goes to ``<output>.npz`` and metadata
        to ``<output>_params.json`` under ``PROJECT_ROOT / "data"``.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Training Data Generation")
    LOGGER.info("=" * 80)

    output_dir = PROJECT_ROOT / "data"
    output_dir.mkdir(exist_ok=True)

    LOGGER.info("")
    LOGGER.info("SAMPLING STRATEGY")
    LOGGER.info("=" * 80)
    LOGGER.info("Forward Munsell->xyY sampling:")
    LOGGER.info(
        " - Base: %d colors from MUNSELL_COLOURS_ALL", len(MUNSELL_COLOURS_ALL)
    )
    LOGGER.info(
        " - Perturbations: +/-%.0f%% of valid range per component",
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Hue: +/-%.2f (+/-%.0f%% of 9.5 range)",
        perturbation_pct * 9.5,
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Value: +/-%.2f (+/-%.0f%% of 9.0 range)",
        perturbation_pct * 9.0,
        perturbation_pct * 100,
    )
    LOGGER.info(
        " - Chroma: +/-%.1f (+/-%.0f%% of 50.0 range)",
        perturbation_pct * 50.0,
        perturbation_pct * 100,
    )
    LOGGER.info(" - Target samples: %d", n_samples)
    LOGGER.info("=" * 80)
    LOGGER.info("")

    # Generate samples
    xyY_all, munsell_all = generate_forward_munsell_samples(
        n_samples=n_samples,
        perturbation_pct=perturbation_pct,
    )

    # All-ones by construction (forward sampling only yields valid pairs);
    # kept — presumably for schema compatibility with downstream loaders.
    valid_mask = np.ones(len(xyY_all), dtype=bool)

    LOGGER.info("")
    LOGGER.info("Sample statistics:")
    LOGGER.info(" Total samples generated: %d", len(xyY_all))
    LOGGER.info(" All samples are valid (100%% by forward conversion)")

    LOGGER.info("")
    LOGGER.info("Using %d valid samples for training", len(xyY_all))

    # Split into train/validation/test (70/15/15):
    # 15% to test first, then 0.15/0.85 of the remainder to validation so
    # validation is 15% of the original total.
    X_temp, X_test, y_temp, y_test = train_test_split(
        xyY_all, munsell_all, test_size=0.15, random_state=42
    )
    X_train, X_val, y_train, y_val = train_test_split(
        X_temp, y_temp, test_size=0.15 / 0.85, random_state=42
    )

    LOGGER.info("")
    LOGGER.info("Data split:")
    LOGGER.info(" Train: %d samples", len(X_train))
    LOGGER.info(" Validation: %d samples", len(X_val))
    LOGGER.info(" Test: %d samples", len(X_test))

    # Save training data
    cache_file = output_dir / f"{output}.npz"
    np.savez_compressed(
        cache_file,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        X_test=X_test,
        y_test=y_test,
        xyY_all=xyY_all,
        munsell_all=munsell_all,
        valid_mask=valid_mask,
    )

    # Save parameters to sidecar file
    params_file = output_dir / f"{output}_params.json"
    params = {
        "n_samples": n_samples,
        "perturbation_pct": perturbation_pct,
        "n_base_colors": len(MUNSELL_COLOURS_ALL),
        "train_samples": len(X_train),
        "val_samples": len(X_val),
        "test_samples": len(X_test),
        "generated_at": datetime.now(timezone.utc).isoformat(),
    }
    with open(params_file, "w") as f:
        json.dump(params, f, indent=2)

    LOGGER.info("")
    LOGGER.info("Training data saved to: %s", cache_file)
    LOGGER.info("Parameters saved to: %s", params_file)
    LOGGER.info("=" * 80)
280
+
281
+
282
# CLI entry point: parse generation options and delegate to main().
if __name__ == "__main__":
    parser = argparse.ArgumentParser(
        description="Generate training data for xyY to Munsell conversion"
    )
    parser.add_argument(
        "--n-samples",
        type=int,
        default=500000,
        help="Number of samples to generate (default: 500000)",
    )
    parser.add_argument(
        "--perturbation",
        type=float,
        default=0.05,
        help="Perturbation as fraction of valid range (default: 0.05)",
    )
    parser.add_argument(
        "--output",
        type=str,
        default="training_data",
        help="Output filename without extension (default: training_data)",
    )
    args = parser.parse_args()

    main(
        n_samples=args.n_samples,
        perturbation_pct=args.perturbation,
        output=args.output,
    )
learning_munsell/interpolation/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Interpolation-based methods for Munsell conversions."""
learning_munsell/interpolation/from_xyY/__init__.py ADDED
@@ -0,0 +1,43 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Interpolation-based methods for xyY to Munsell conversions."""
2
+
3
+ import numpy as np
4
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
5
+ from colour.notation.munsell import munsell_colour_to_munsell_specification
6
+ from numpy.typing import NDArray
7
+
8
+
9
def load_munsell_reference_data() -> tuple[NDArray, NDArray]:
    """
    Load reference Munsell data from the colour library.

    Every entry of ``MUNSELL_COLOURS_ALL`` is converted to a numeric
    specification ``[hue, value, chroma, code]``, and the Y luminance
    component of its xyY coordinates is rescaled by 1/100 into [0, 1]
    (native range ~0-102.57).

    Returns
    -------
    Tuple[NDArray, NDArray]
        X : xyY values of shape (4995, 3) with Y normalized to [0, 1]
        y : Munsell specifications of shape (4995, 4)
    """
    coordinates = []
    specifications = []

    for (hue_name, value, chroma), xyY in MUNSELL_COLOURS_ALL:
        notation = f"{hue_name} {value}/{chroma}"

        # Parse the notation string into [hue, value, chroma, code].
        specifications.append(munsell_colour_to_munsell_specification(notation))

        # Rescale Y into [0, 1]; x and y chromaticities pass through.
        coordinates.append(np.array([xyY[0], xyY[1], xyY[2] / 100.0]))

    return np.array(coordinates), np.array(specifications)
41
+
42
+
43
+ __all__ = ["load_munsell_reference_data"]
learning_munsell/interpolation/from_xyY/compare_methods.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Compare classical interpolation methods against the best ML model.
3
+
4
+ Evaluates RBF, KD-Tree, and Delaunay interpolation on REAL Munsell colors
5
+ and compares with the Multi-Head (W+B) + Multi-Error Predictor (W+B) model.
6
+ """
7
+
8
+ import logging
9
+
10
+ import numpy as np
11
+ import onnxruntime as ort
12
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_ALL
13
+ from colour.notation.munsell import munsell_colour_to_munsell_specification
14
+ from scipy.interpolate import LinearNDInterpolator, RBFInterpolator
15
+ from scipy.spatial import KDTree
16
+ from sklearn.model_selection import train_test_split
17
+
18
+ from learning_munsell import PROJECT_ROOT
19
+
20
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
21
+ LOGGER = logging.getLogger(__name__)
22
+
23
+
24
def load_reference_data():
    """Load every reference Munsell colour as interpolator training data."""
    xyY_points, specifications = [], []
    for (hue_name, value, chroma), xyY in MUNSELL_COLOURS_ALL:
        notation = f"{hue_name} {value}/{chroma}"
        specifications.append(munsell_colour_to_munsell_specification(notation))
        # Rescale Y from ~[0, 102.57] into [0, 1].
        xyY_points.append([xyY[0], xyY[1], xyY[2] / 100.0])
    return np.array(xyY_points), np.array(specifications)
35
+
36
+
37
+
38
+
39
def evaluate(predictions, y_true, method_name):
    """Log and return the per-component MAE of *predictions* vs *y_true*."""
    component_names = ("hue", "value", "chroma", "code")
    abs_errors = np.abs(predictions - y_true)
    # One MAE per specification component, column order hue/value/chroma/code.
    results = {
        name: abs_errors[:, column].mean()
        for column, name in enumerate(component_names)
    }
    LOGGER.info(" %s:", method_name)
    for name in component_names:
        LOGGER.info(" %s MAE: %.4f", name.capitalize(), results[name])
    return results
52
+
53
+
54
def rbf_predict(X_train, y_train, X_test):
    """Predict Munsell components with one RBF interpolant per component."""
    component_columns = []
    # Fit an independent thin-plate-spline interpolant for each of the
    # four specification components.
    for component in range(4):
        interpolant = RBFInterpolator(
            X_train, y_train[:, component], kernel="thin_plate_spline"
        )
        component_columns.append(interpolant(X_test))
    return np.stack(component_columns, axis=1)
61
+
62
+
63
def kdtree_predict(X_train, y_train, X_test, k=5):
    """
    Predict Munsell components via k-nearest-neighbour inverse-distance weighting.

    Parameters
    ----------
    X_train : ndarray, shape (n_train, d)
        Reference xyY coordinates.
    y_train : ndarray, shape (n_train, 4)
        Munsell specifications matching ``X_train``.
    X_test : ndarray, shape (n_test, d)
        Query xyY coordinates.
    k : int
        Number of neighbours to blend (default 5); ``k=1`` now works too.

    Returns
    -------
    ndarray, shape (n_test, 4)
        Inverse-distance-squared weighted average of the k nearest
        neighbours' specifications; an exact match dominates the blend.
    """
    tree = KDTree(X_train)
    distances, indices = tree.query(X_test, k=k)

    # KDTree.query squeezes the neighbour axis when k == 1 (the original
    # implementation crashed there); restore a 2-D shape for any k.
    distances = distances.reshape(len(X_test), -1)
    indices = indices.reshape(len(X_test), -1)

    # Clamp to avoid division by zero for exact matches; a clamped point
    # receives an overwhelmingly large weight, i.e. is returned as-is.
    distances = np.maximum(distances, 1e-10)
    weights = 1.0 / (distances**2)
    weights /= weights.sum(axis=1, keepdims=True)

    # Vectorised weighted average over the neighbour axis, replacing the
    # original per-sample Python loop (same result, O(1) Python overhead).
    return np.sum(weights[:, :, np.newaxis] * y_train[indices], axis=1)
75
+
76
+
77
def delaunay_predict(X_train, y_train, X_test):
    """Piecewise-linear (Delaunay) interpolation with NN fallback."""
    fallback_tree = KDTree(X_train)

    # One linear interpolant per specification component, stacked column-wise.
    component_columns = [
        LinearNDInterpolator(X_train, y_train[:, component])(X_test)
        for component in range(4)
    ]
    predictions = np.stack(component_columns, axis=1)

    # Queries outside the convex hull come back as NaN; substitute the
    # nearest training sample's specification for those rows.
    outside_hull = np.any(np.isnan(predictions), axis=1)
    if outside_hull.sum() > 0:
        _, nearest = fallback_tree.query(X_test[outside_hull])
        predictions[outside_hull] = y_train[nearest]

    return predictions
93
+
94
+
95
def ml_predict(X_test):
    """
    ML model prediction using base + error predictor.

    Runs the two-stage ONNX pipeline (multi-head base model followed by a
    multi-error-predictor correction) on normalized xyY inputs and returns
    denormalized Munsell specifications ``[hue, value, chroma, code]``, or
    ``None`` when either model file is missing on disk.
    """
    base_path = PROJECT_ROOT / "models" / "from_xyY" / "multi_head_weighted_boundary.onnx"
    error_path = (
        PROJECT_ROOT
        / "models"
        / "from_xyY"
        / "multi_head_weighted_boundary_multi_error_predictor_weighted_boundary.onnx"
    )

    # Graceful skip so the comparison script still runs without the models.
    if not base_path.exists() or not error_path.exists():
        return None

    # Input is already normalized to [0, 1] for x, y, Y
    X_norm = X_test.astype(np.float32)

    # Base model prediction
    base_session = ort.InferenceSession(str(base_path))
    base_out = base_session.run(None, {"xyY": X_norm})[0]

    # Error predictor (takes xyY + base predictions)
    error_session = ort.InferenceSession(str(error_path))
    combined_input = np.concatenate([X_norm, base_out], axis=1).astype(np.float32)
    error_out = error_session.run(None, {"combined_input": combined_input})[0]

    # Combined prediction (normalized)
    pred_norm = base_out + error_out

    # Denormalize using actual ranges from params file.
    # NOTE(review): ranges are hard-coded here — confirm they match the
    # values stored in the model's normalization-params .npz sidecar.
    predictions = np.zeros_like(pred_norm)
    predictions[:, 0] = pred_norm[:, 0] * (10.0 - 0.5) + 0.5  # Hue: [0.5, 10]
    predictions[:, 1] = pred_norm[:, 1] * (10.0 - 0.0) + 0.0  # Value: [0, 10]
    predictions[:, 2] = pred_norm[:, 2] * (50.0 - 0.0) + 0.0  # Chroma: [0, 50]
    predictions[:, 3] = pred_norm[:, 3] * (10.0 - 1.0) + 1.0  # Code: [1, 10]

    return predictions
131
+
132
+
133
def main():
    """
    Compare all methods using held-out test set.

    Fits RBF, KD-Tree and Delaunay interpolators on an 80% split of the
    reference colours, evaluates all of them (plus the pre-trained ML
    model when available) on the remaining 20%, and logs a MAE summary.

    NOTE(review): the ML model is loaded pre-trained and was presumably
    fitted on the full reference dataset, so the 20% "held-out" set may
    overlap its training data — confirm before comparing head-to-head.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Classical Interpolation vs ML Model Comparison")
    LOGGER.info("=" * 80)

    LOGGER.info("")
    LOGGER.info("Loading data...")
    X_all, y_all = load_reference_data()

    # 80/20 train/test split for fair comparison
    X_train, X_test, y_train, y_test = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )
    LOGGER.info(" Total: %d colors", len(X_all))
    LOGGER.info(" Training: %d colors (80%%)", len(X_train))
    LOGGER.info(" Test: %d colors (20%%)", len(X_test))

    results = {}

    # RBF
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("RBF Interpolation (thin_plate_spline)")
    rbf_pred = rbf_predict(X_train, y_train, X_test)
    results["RBF"] = evaluate(rbf_pred, y_test, "RBF")

    # KD-Tree
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("KD-Tree Interpolation (k=5, IDW)")
    kdt_pred = kdtree_predict(X_train, y_train, X_test, k=5)
    results["KD-Tree"] = evaluate(kdt_pred, y_test, "KD-Tree")

    # Delaunay
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("Delaunay Interpolation (with NN fallback)")
    del_pred = delaunay_predict(X_train, y_train, X_test)
    results["Delaunay"] = evaluate(del_pred, y_test, "Delaunay")

    # ML (skipped when the ONNX model files are absent)
    LOGGER.info("")
    LOGGER.info("-" * 60)
    LOGGER.info("ML Model (Multi-Head W+B + Multi-Error Predictor W+B)")
    ml_pred = ml_predict(X_test)
    if ml_pred is not None:
        results["ML"] = evaluate(ml_pred, y_test, "ML")
    else:
        LOGGER.info(" Skipped (model not found)")

    # Summary
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("SUMMARY (MAE on %d held-out test colors)", len(X_test))
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("%-12s %8s %8s %8s %8s", "Method", "Hue", "Value", "Chroma", "Code")
    LOGGER.info("-" * 52)

    for method, mae in results.items():
        LOGGER.info(
            "%-12s %8.4f %8.4f %8.4f %8.4f",
            method,
            mae["hue"],
            mae["value"],
            mae["chroma"],
            mae["code"],
        )

    LOGGER.info("")
    LOGGER.info("=" * 80)
205
+
206
+
207
+ if __name__ == "__main__":
208
+ main()
learning_munsell/interpolation/from_xyY/delaunay_interpolator.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Delaunay triangulation based interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's LinearNDInterpolator which performs piecewise
5
+ linear interpolation based on Delaunay triangulation.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages:
10
+ - Piecewise linear: exact at data points, linear between
11
+ - Handles irregular point distributions
12
+ - No hyperparameters to tune
13
+
14
+ Disadvantages:
15
+ - Returns NaN outside convex hull of data points
16
+ - Non-convex Munsell boundary may cause issues
17
+ - C0 continuous only (discontinuous gradients at cell boundaries)
18
+ """
19
+
20
+ import logging
21
+ import pickle
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ from numpy.typing import NDArray
26
+ from scipy.interpolate import LinearNDInterpolator
27
+ from scipy.spatial import KDTree
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT, setup_logging
31
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class MunsellDelaunayInterpolator:
    """
    xyY -> Munsell interpolator backed by Delaunay triangulation.

    One :class:`LinearNDInterpolator` (piecewise linear over the Delaunay
    triangulation of the training points) is built per Munsell component.
    Queries outside the convex hull produce NaN; those rows can optionally
    be patched with the nearest training sample via a KD-Tree.
    """

    def __init__(self, fallback_to_nearest: bool = True) -> None:
        """
        Initialize the Delaunay interpolator.

        Parameters
        ----------
        fallback_to_nearest
            If True, use nearest neighbor for points outside convex hull.
            If False, return NaN for such points.
        """
        self.fallback_to_nearest = fallback_to_nearest
        self.interpolators: dict = {}
        self.kdtree: KDTree | None = None
        self.y_data: NDArray | None = None
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellDelaunayInterpolator":
        """
        Build the per-component interpolators (and the fallback KD-Tree).

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Building Delaunay interpolator...")
        LOGGER.info(" Fallback to nearest: %s", self.fallback_to_nearest)
        LOGGER.info(" Data points: %d", len(X))

        for column, component in enumerate(["hue", "value", "chroma", "code"]):
            LOGGER.info(" Building %s interpolator...", component)
            self.interpolators[component] = LinearNDInterpolator(X, y[:, column])

        if self.fallback_to_nearest:
            # Nearest-neighbour lookup for queries outside the convex hull.
            LOGGER.info(" Building KD-Tree for fallback...")
            self.kdtree = KDTree(X)
            self.y_data = y.copy()

        self.fitted = True
        LOGGER.info("Delaunay interpolator built successfully")
        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values using Delaunay interpolation.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        predictions = np.zeros((len(X), 4))
        for column, component in enumerate(["hue", "value", "chroma", "code"]):
            predictions[:, column] = self.interpolators[component](X)

        if self.fallback_to_nearest:
            # A row with any NaN lies outside the triangulation's hull; all
            # four interpolators share the triangulation, so patch whole rows.
            outside = np.any(np.isnan(predictions), axis=1)
            n_outside = outside.sum()
            if n_outside > 0:
                LOGGER.debug(" %d points outside hull, using nearest neighbor", n_outside)
                _, nearest = self.kdtree.query(X[outside])
                predictions[outside] = self.y_data[nearest]

        return predictions

    def save(self, path: Path) -> None:
        """Save the interpolator to disk."""
        state = {
            "fallback_to_nearest": self.fallback_to_nearest,
            "interpolators": self.interpolators,
            "kdtree": self.kdtree,
            "y_data": self.y_data,
        }
        with open(path, "wb") as f:
            pickle.dump(state, f)
        LOGGER.info("Saved Delaunay interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellDelaunayInterpolator":
        """Load the interpolator from disk."""
        with open(path, "rb") as f:
            state = pickle.load(f)  # noqa: S301

        instance = cls(fallback_to_nearest=state["fallback_to_nearest"])
        instance.interpolators = state["interpolators"]
        instance.kdtree = state["kdtree"]
        instance.y_data = state["y_data"]
        instance.fitted = True

        LOGGER.info("Loaded Delaunay interpolator from %s", path)
        return instance
161
+
162
+
163
def evaluate_delaunay(
    interpolator: MunsellDelaunayInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict:
    """Evaluate Delaunay interpolator performance."""
    predictions = interpolator.predict(X)

    # Rows containing NaN were not interpolable (outside hull, no fallback).
    row_has_nan = np.isnan(predictions).any(axis=1)
    nan_count = row_has_nan.sum()
    if nan_count > 0:
        LOGGER.warning(" %d/%d predictions contain NaN", nan_count, len(X))

    valid_mask = ~row_has_nan
    if valid_mask.sum() == 0:
        LOGGER.error(" All predictions are NaN!")
        return {key: float("nan") for key in ("hue", "value", "chroma", "code")}

    errors = np.abs(predictions[valid_mask] - y[valid_mask])

    results = {}
    LOGGER.info("%s set MAE (%d/%d valid):", name, valid_mask.sum(), len(X))
    for column, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        mae = errors[:, column].mean()
        results[label.lower()] = mae
        LOGGER.info(" %s: %.4f", label, mae)

    return results
200
+
201
+
202
def main() -> None:
    """Build and evaluate Delaunay interpolator using reference Munsell data.

    Performs a train/validation sweep over the fallback setting, then refits
    the best configuration on all reference colors and saves it to disk.
    """

    log_file = setup_logging("delaunay_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("Delaunay Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Load reference data from colour library
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Split into train/validation (80/20)
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Test with and without fallback
    LOGGER.info("")
    LOGGER.info("Testing Delaunay interpolation...")
    LOGGER.info("-" * 60)

    best_config = None
    best_mae = float("inf")

    for fallback in [True, False]:
        LOGGER.info("")
        LOGGER.info("Fallback to nearest: %s", fallback)

        interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=fallback)
        interpolator.fit(X_train, y_train)

        results = evaluate_delaunay(interpolator, X_val, y_val, "Validation")

        # Skip if results contain NaN
        if any(np.isnan(v) for v in results.values()):
            LOGGER.info(" Skipping due to NaN results")
            continue

        total_mae = sum(results.values())

        if total_mae < best_mae:
            best_mae = total_mae
            best_config = fallback

    if best_config is None:
        # BUG FIX: if every configuration was skipped (all-NaN results),
        # best_config stayed None and the final model was built with
        # fallback_to_nearest=None. Default to the safe configuration.
        LOGGER.warning(
            "No valid configuration found; defaulting to fallback_to_nearest=True"
        )
        best_config = True

    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best configuration: fallback_to_nearest=%s", best_config)
    LOGGER.info("=" * 60)

    # Train final model on ALL data
    LOGGER.info("")
    LOGGER.info("Training final model on all %d reference colors...", len(X_all))

    final_interpolator = MunsellDelaunayInterpolator(fallback_to_nearest=best_config)
    final_interpolator.fit(X_all, y_all)

    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_delaunay(final_interpolator, X_all, y_all, "All data")

    # Save the model
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "delaunay_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()


if __name__ == "__main__":
    main()
learning_munsell/interpolation/from_xyY/kdtree_interpolator.py ADDED
@@ -0,0 +1,263 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ KD-Tree based interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's KDTree for fast nearest neighbor lookups,
5
+ with optional weighted interpolation using k nearest neighbors.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages over RBF:
10
+ - O(n) memory, O(log n) query time
11
+ - Scales to millions of data points
12
+ - No matrix inversion required
13
+
14
+ Advantages over ML:
15
+ - Deterministic
16
+ - No training required
17
+ - Easy to understand
18
+ """
19
+
20
+ import logging
21
+ import pickle
22
+ from pathlib import Path
23
+
24
+ import numpy as np
25
+ from numpy.typing import NDArray
26
+ from scipy.spatial import KDTree
27
+ from sklearn.model_selection import train_test_split
28
+
29
+ from learning_munsell import PROJECT_ROOT, setup_logging
30
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
31
+
32
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
class MunsellKDTreeInterpolator:
    """
    KD-Tree based interpolator for xyY to Munsell conversion.

    Uses k-nearest neighbors with inverse distance weighting
    for smooth interpolation.
    """

    def __init__(self, k: int = 5, power: float = 2.0) -> None:
        """
        Initialize the KD-Tree interpolator.

        Parameters
        ----------
        k
            Number of nearest neighbors to use for interpolation.
        power
            Power for inverse distance weighting. Higher = sharper.
        """
        self.k = k
        self.power = power
        self.tree: KDTree | None = None
        self.y_data: NDArray | None = None
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellKDTreeInterpolator":
        """
        Build the KD-Tree from training data.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Building KD-Tree interpolator...")
        LOGGER.info(" k neighbors: %d", self.k)
        LOGGER.info(" IDW power: %.1f", self.power)
        LOGGER.info(" Data points: %d", len(X))

        self.tree = KDTree(X)
        self.y_data = y.copy()
        self.fitted = True

        LOGGER.info("KD-Tree built successfully")
        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values using k-NN with IDW.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        # Query k nearest neighbors
        distances, indices = self.tree.query(X, k=self.k)

        # KDTree.query returns 1-D arrays when k == 1; normalise to (n, k)
        # so the weighting below handles every k uniformly.
        if self.k == 1:
            distances = distances.reshape(-1, 1)
            indices = indices.reshape(-1, 1)

        # Inverse distance weighting; clamp distances so a query that
        # coincides exactly with a training point does not divide by zero.
        distances = np.maximum(distances, 1e-10)
        weights = 1.0 / (distances**self.power)
        weights /= weights.sum(axis=1, keepdims=True)

        # PERF FIX: the original looped over samples in Python (with a
        # redundant per-row k == 1 branch). A single broadcast computes the
        # same weighted average: weights (n, k, 1) * neighbours (n, k, 4).
        return np.sum(weights[:, :, np.newaxis] * self.y_data[indices], axis=1)

    def save(self, path: Path) -> None:
        """Save the interpolator to disk."""
        with open(path, "wb") as f:
            pickle.dump(
                {
                    "k": self.k,
                    "power": self.power,
                    "tree": self.tree,
                    "y_data": self.y_data,
                },
                f,
            )
        LOGGER.info("Saved KD-Tree interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellKDTreeInterpolator":
        """Load the interpolator from disk."""
        with open(path, "rb") as f:
            data = pickle.load(f)  # noqa: S301

        instance = cls(k=data["k"], power=data["power"])
        instance.tree = data["tree"]
        instance.y_data = data["y_data"]
        instance.fitted = True

        LOGGER.info("Loaded KD-Tree interpolator from %s", path)
        return instance
158
+
159
+
160
def evaluate_kdtree(
    interpolator: MunsellKDTreeInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict:
    """Evaluate KD-Tree interpolator performance."""
    # Per-component absolute errors against ground truth.
    errors = np.abs(interpolator.predict(X) - y)

    results = {}
    LOGGER.info("%s set MAE:", name)
    for column, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        mae = errors[:, column].mean()
        results[label.lower()] = mae
        LOGGER.info(" %s: %.4f", label, mae)

    return results
180
+
181
+
182
def main() -> None:
    """Build and evaluate KD-Tree interpolator using reference Munsell data."""

    log_file = setup_logging("kdtree_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("KD-Tree Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Reference colors come from the colour library.
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Hold out 20% for validating the neighbour-count sweep.
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    best_k = None
    best_mae = float("inf")

    LOGGER.info("")
    LOGGER.info("Testing different k values...")
    LOGGER.info("-" * 60)

    # Sweep neighbour counts, keeping the k with the lowest summed MAE.
    for k in [1, 3, 5, 10, 20, 50]:
        LOGGER.info("")
        LOGGER.info("k = %d:", k)

        candidate = MunsellKDTreeInterpolator(k=k, power=2.0)
        candidate.fit(X_train, y_train)

        val_results = evaluate_kdtree(candidate, X_val, y_val, "Validation")
        combined_mae = sum(val_results.values())

        if combined_mae < best_mae:
            best_mae = combined_mae
            best_k = k

    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best k: %d", best_k)
    LOGGER.info("=" * 60)

    # Refit the winning configuration on the full reference set.
    LOGGER.info("")
    LOGGER.info(
        "Training final model on all %d reference colors with k=%d...",
        len(X_all),
        best_k,
    )

    final_interpolator = MunsellKDTreeInterpolator(k=best_k, power=2.0)
    final_interpolator.fit(X_all, y_all)

    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_kdtree(final_interpolator, X_all, y_all, "All data")

    # Persist the final model.
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "kdtree_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()


if __name__ == "__main__":
    main()
learning_munsell/interpolation/from_xyY/rbf_interpolator.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ RBF (Radial Basis Function) interpolation for xyY to Munsell conversion.
3
+
4
+ This approach uses scipy's RBFInterpolator to build a lookup table
5
+ with smooth interpolation between known color samples.
6
+
7
+ Uses the 4,995 reference colors from MUNSELL_COLOURS_ALL directly.
8
+
9
+ Advantages over ML:
10
+ - Deterministic, no training required
11
+ - Exact interpolation at known points
12
+ - Smooth interpolation between points
13
+ - Easy to understand and debug
14
+
15
+ Disadvantages:
16
+ - Memory scales with number of data points
17
+ - Query time scales with data points (O(n) naive, can optimize)
18
+ - May struggle with extrapolation
19
+ """
20
+
21
+ import logging
22
+ import pickle
23
+ from pathlib import Path
24
+
25
+ import numpy as np
26
+ from numpy.typing import NDArray
27
+ from scipy.interpolate import RBFInterpolator
28
+ from sklearn.model_selection import train_test_split
29
+
30
+ from learning_munsell import PROJECT_ROOT, setup_logging
31
+ from learning_munsell.interpolation.from_xyY import load_munsell_reference_data
32
+
33
+ logging.basicConfig(level=logging.INFO, format="%(message)s")
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class MunsellRBFInterpolator:
    """
    RBF-based interpolator for xyY to Munsell conversion.

    One independent :class:`RBFInterpolator` is fitted per Munsell component
    (hue, value, chroma, code) so each can be tuned independently.
    """

    def __init__(
        self,
        kernel: str = "thin_plate_spline",
        smoothing: float = 0.0,
        epsilon: float | None = None,
    ) -> None:
        """
        Initialize the RBF interpolator.

        Parameters
        ----------
        kernel
            RBF kernel type. Options: 'linear', 'thin_plate_spline',
            'cubic', 'quintic', 'multiquadric', 'inverse_multiquadric',
            'inverse_quadratic', 'gaussian'
        smoothing
            Smoothing parameter. 0 = exact interpolation.
        epsilon
            Shape parameter for kernels that use it.
        """
        self.kernel = kernel
        self.smoothing = smoothing
        self.epsilon = epsilon
        self.interpolators: dict[str, RBFInterpolator] = {}
        self.fitted = False

    def fit(self, X: NDArray, y: NDArray) -> "MunsellRBFInterpolator":
        """
        Fit one RBF interpolator per Munsell component.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)
        y
            Munsell output values [hue, value, chroma, code] of shape (n, 4)

        Returns
        -------
        self
        """
        LOGGER.info("Fitting RBF interpolators...")
        LOGGER.info(" Kernel: %s", self.kernel)
        LOGGER.info(" Smoothing: %s", self.smoothing)
        LOGGER.info(" Data points: %d", len(X))

        # Constructor options are identical for every component; epsilon is
        # only forwarded when set, since not all kernels accept it.
        options: dict = {"kernel": self.kernel, "smoothing": self.smoothing}
        if self.epsilon is not None:
            options["epsilon"] = self.epsilon

        for column, component in enumerate(["hue", "value", "chroma", "code"]):
            LOGGER.info(" Building %s interpolator...", component)
            self.interpolators[component] = RBFInterpolator(X, y[:, column], **options)

        self.fitted = True
        LOGGER.info("RBF interpolators fitted successfully")

        return self

    def predict(self, X: NDArray) -> NDArray:
        """
        Predict Munsell values for given xyY inputs.

        Parameters
        ----------
        X
            xyY input values of shape (n, 3)

        Returns
        -------
        NDArray
            Predicted Munsell values [hue, value, chroma, code] of shape (n, 4)

        Raises
        ------
        RuntimeError
            If called before ``fit``.
        """
        if not self.fitted:
            msg = "Interpolator not fitted. Call fit() first."
            raise RuntimeError(msg)

        predictions = np.zeros((len(X), 4))
        for column, component in enumerate(["hue", "value", "chroma", "code"]):
            predictions[:, column] = self.interpolators[component](X)

        return predictions

    def save(self, path: Path) -> None:
        """Save the interpolator to disk."""
        state = {
            "kernel": self.kernel,
            "smoothing": self.smoothing,
            "epsilon": self.epsilon,
            "interpolators": self.interpolators,
        }
        with open(path, "wb") as f:
            pickle.dump(state, f)
        LOGGER.info("Saved RBF interpolator to %s", path)

    @classmethod
    def load(cls, path: Path) -> "MunsellRBFInterpolator":
        """Load the interpolator from disk."""
        with open(path, "rb") as f:
            state = pickle.load(f)  # noqa: S301

        instance = cls(
            kernel=state["kernel"],
            smoothing=state["smoothing"],
            epsilon=state["epsilon"],
        )
        instance.interpolators = state["interpolators"]
        instance.fitted = True

        LOGGER.info("Loaded RBF interpolator from %s", path)
        return instance
166
+
167
+
168
def evaluate_rbf(
    interpolator: MunsellRBFInterpolator,
    X: NDArray,
    y: NDArray,
    name: str = "Test",
) -> dict[str, float]:
    """
    Evaluate RBF interpolator performance.

    Parameters
    ----------
    interpolator
        Fitted RBF interpolator
    X
        Input xyY values
    y
        Ground truth Munsell values
    name
        Name for logging

    Returns
    -------
    dict
        Dictionary of MAE values for each component
    """
    # Per-component absolute errors against ground truth.
    errors = np.abs(interpolator.predict(X) - y)

    results = {}
    LOGGER.info("%s set MAE:", name)
    for column, label in enumerate(["Hue", "Value", "Chroma", "Code"]):
        mae = errors[:, column].mean()
        results[label.lower()] = mae
        LOGGER.info(" %s: %.4f", label, mae)

    return results
206
+
207
+
208
def main() -> None:
    """Build and evaluate RBF interpolator using reference Munsell data.

    Sweeps several kernel/smoothing combinations on a held-out split, then
    refits the best configuration on all reference colors and saves it.
    """

    log_file = setup_logging("rbf_interpolator", "from_xyY")

    LOGGER.info("=" * 80)
    LOGGER.info("RBF Interpolation for xyY to Munsell Conversion")
    LOGGER.info("Using MUNSELL_COLOURS_ALL reference data (4,995 colors)")
    LOGGER.info("=" * 80)

    # Load reference data from colour library
    LOGGER.info("")
    LOGGER.info("Loading reference Munsell data...")
    X_all, y_all = load_munsell_reference_data()
    LOGGER.info("Total reference colors: %d", len(X_all))

    # Split into train/validation (80/20)
    X_train, X_val, y_train, y_val = train_test_split(
        X_all, y_all, test_size=0.2, random_state=42
    )

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Test different kernels
    kernels_to_test = [
        ("thin_plate_spline", 0.0),
        ("thin_plate_spline", 0.001),
        ("thin_plate_spline", 0.01),
        ("cubic", 0.0),
        ("linear", 0.0),
        ("multiquadric", 0.0),
    ]

    best_kernel = None
    best_smoothing = None
    best_mae = float("inf")

    LOGGER.info("")
    LOGGER.info("Testing different RBF kernels...")
    LOGGER.info("-" * 60)

    for kernel, smoothing in kernels_to_test:
        LOGGER.info("")
        LOGGER.info("Kernel: %s, Smoothing: %s", kernel, smoothing)

        try:
            interpolator = MunsellRBFInterpolator(kernel=kernel, smoothing=smoothing)
            interpolator.fit(X_train, y_train)

            results = evaluate_rbf(interpolator, X_val, y_val, "Validation")
            total_mae = sum(results.values())

            if total_mae < best_mae:
                best_mae = total_mae
                best_kernel = kernel
                best_smoothing = smoothing

        except Exception:
            LOGGER.exception(" Failed")

    if best_kernel is None:
        # BUG FIX: if every kernel raised, best_kernel/best_smoothing stayed
        # None and the final fit below crashed with an opaque error inside
        # scipy. Fall back to the default configuration instead.
        LOGGER.warning(
            "All kernel configurations failed; defaulting to thin_plate_spline"
        )
        best_kernel, best_smoothing = "thin_plate_spline", 0.0

    LOGGER.info("")
    LOGGER.info("=" * 60)
    LOGGER.info("Best configuration: %s with smoothing=%s", best_kernel, best_smoothing)
    LOGGER.info("=" * 60)

    # Train final model with best kernel on ALL data
    LOGGER.info("")
    LOGGER.info("Training final model on all %d reference colors...", len(X_all))

    final_interpolator = MunsellRBFInterpolator(
        kernel=best_kernel, smoothing=best_smoothing
    )
    final_interpolator.fit(X_all, y_all)

    LOGGER.info("")
    LOGGER.info("Final evaluation (training set = all data):")
    evaluate_rbf(final_interpolator, X_all, y_all, "All data")

    # Save the model
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    model_dir.mkdir(parents=True, exist_ok=True)
    model_path = model_dir / "rbf_interpolator.pkl"
    final_interpolator.save(model_path)

    LOGGER.info("")
    LOGGER.info("=" * 80)

    log_file.close()


if __name__ == "__main__":
    main()
learning_munsell/losses/__init__.py ADDED
@@ -0,0 +1,17 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """Loss functions for Munsell ML training."""
2
+
3
+ from learning_munsell.losses.jax_delta_e import (
4
+ XYZ_to_Lab,
5
+ delta_E_CIE2000,
6
+ delta_E_loss,
7
+ xyY_to_Lab,
8
+ xyY_to_XYZ,
9
+ )
10
+
11
+ __all__ = [
12
+ "delta_E_CIE2000",
13
+ "delta_E_loss",
14
+ "xyY_to_Lab",
15
+ "xyY_to_XYZ",
16
+ "XYZ_to_Lab",
17
+ ]
learning_munsell/losses/jax_delta_e.py ADDED
@@ -0,0 +1,299 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Differentiable Delta-E Loss Functions using JAX
3
+ ================================================
4
+
5
+ This module provides JAX implementations of color space conversions
6
+ and Delta-E (CIE2000) loss function for use in training.
7
+
8
+ The key insight is that we can compute Delta-E between:
9
+ - The input xyY (which we convert to Lab as the "target")
10
+ - The predicted Munsell converted back to Lab
11
+
12
+ For the Munsell -> xyY conversion, we either:
13
+ 1. Use a pre-trained neural network approximator
14
+ 2. Use differentiable interpolation on the Munsell Renotation data
15
+ """
16
+
17
+ from __future__ import annotations
18
+
19
+ import colour
20
+ import jax
21
+ import jax.numpy as jnp
22
+ import numpy as np
23
+ from jax import Array
24
+
25
# D65 illuminant XYZ reference values (standard for sRGB).
# On the 0-100 scale (Y = 100 for the reference white), matching the
# Y * 100 scaling applied in xyY_to_XYZ below.
D65_XYZ = jnp.array([95.047, 100.0, 108.883])

# Illuminant C XYZ reference values (used by Munsell system), same 0-100 scale.
ILLUMINANT_C_XYZ = jnp.array([98.074, 100.0, 118.232])
30
+
31
+
32
def xyY_to_XYZ(xyY: Array, scale_Y: bool = True) -> Array:
    """
    Convert CIE xyY to CIE XYZ.

    Parameters
    ----------
    xyY : Array
        CIE xyY values with shape (..., 3)
    scale_Y : bool
        If True, scale Y from 0-1 to 0-100 range (required for Lab conversion)

    Returns
    -------
    Array
        CIE XYZ values with shape (..., 3)
    """
    x = xyY[..., 0]
    y = xyY[..., 1]
    # The colour library works on the 0-100 luminance range.
    Y = xyY[..., 2] * 100.0 if scale_Y else xyY[..., 2]

    # Guard the division; rows with y == 0 are forced to X = Z = 0 below.
    is_zero = y == 0
    safe_y = jnp.where(is_zero, 1e-10, y)

    X = jnp.where(is_zero, 0.0, (x * Y) / safe_y)
    Z = jnp.where(is_zero, 0.0, ((1 - x - y) * Y) / safe_y)

    return jnp.stack([X, Y, Z], axis=-1)
67
+
68
+
69
def XYZ_to_Lab(XYZ: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
    """
    Convert CIE XYZ to CIE Lab.

    Parameters
    ----------
    XYZ : Array
        CIE XYZ values with shape (..., 3)
    illuminant : Array
        Reference white XYZ values

    Returns
    -------
    Array
        CIE Lab values with shape (..., 3)
    """
    # Normalize against the reference white.
    t = XYZ / illuminant

    delta = 6.0 / 29.0

    # Piecewise CIE transfer function: cube root above (6/29)^3, linear below.
    def transfer(u: Array) -> Array:
        linear = u / (3 * delta**2) + 4.0 / 29.0
        return jnp.where(u > delta**3, jnp.cbrt(u), linear)

    fx = transfer(t[..., 0])
    fy = transfer(t[..., 1])
    fz = transfer(t[..., 2])

    return jnp.stack(
        [116.0 * fy - 16.0, 500.0 * (fx - fy), 200.0 * (fy - fz)],
        axis=-1,
    )
105
+
106
+
107
def xyY_to_Lab(xyY: Array, illuminant: Array = ILLUMINANT_C_XYZ) -> Array:
    """Convert CIE xyY directly to CIE Lab."""
    # Two-step pipeline: xyY -> XYZ, then XYZ -> Lab under the given white.
    XYZ = xyY_to_XYZ(xyY)
    return XYZ_to_Lab(XYZ, illuminant)
110
+
111
+
112
def delta_E_CIE2000(Lab_1: Array, Lab_2: Array) -> Array:
    """
    Compute CIE 2000 Delta-E color difference.

    This is a differentiable JAX implementation of the CIE 2000 Delta-E formula.

    Parameters
    ----------
    Lab_1 : Array
        First CIE Lab color(s) with shape (..., 3)
    Lab_2 : Array
        Second CIE Lab color(s) with shape (..., 3)

    Returns
    -------
    Array
        Delta-E values with shape (...)

    Notes
    -----
    Follows the structure of Sharma, Wu & Dalal (2005), "The CIEDE2000
    Color-Difference Formula: Implementation Notes, Supplementary Test Data,
    and Mathematical Observations". All piecewise branches are expressed
    with ``jnp.where`` so the function stays traceable and differentiable
    under ``jax.jit`` / ``jax.grad``.
    """
    L_1, a_1, b_1 = Lab_1[..., 0], Lab_1[..., 1], Lab_1[..., 2]
    L_2, a_2, b_2 = Lab_2[..., 0], Lab_2[..., 1], Lab_2[..., 2]

    # Chroma
    C_1_ab = jnp.sqrt(a_1**2 + b_1**2)
    C_2_ab = jnp.sqrt(a_2**2 + b_2**2)

    C_bar_ab = (C_1_ab + C_2_ab) / 2
    C_bar_ab_7 = C_bar_ab**7

    # G factor for a' adjustment (25^7 = 6103515625.0)
    G = 0.5 * (1 - jnp.sqrt(C_bar_ab_7 / (C_bar_ab_7 + 6103515625.0)))

    # Adjusted a'
    a_p_1 = (1 + G) * a_1
    a_p_2 = (1 + G) * a_2

    # Adjusted chroma C'
    C_p_1 = jnp.sqrt(a_p_1**2 + b_1**2)
    C_p_2 = jnp.sqrt(a_p_2**2 + b_2**2)

    # Hue angle h' (in degrees), wrapped to [0, 360)
    h_p_1 = jnp.degrees(jnp.arctan2(b_1, a_p_1)) % 360
    h_p_2 = jnp.degrees(jnp.arctan2(b_2, a_p_2)) % 360

    # Handle achromatic case (a' = b' = 0): hue angle is defined as 0
    h_p_1 = jnp.where((b_1 == 0) & (a_p_1 == 0), 0.0, h_p_1)
    h_p_2 = jnp.where((b_2 == 0) & (a_p_2 == 0), 0.0, h_p_2)

    # Delta L', C'
    delta_L_p = L_2 - L_1
    delta_C_p = C_p_2 - C_p_1

    # Delta h'
    h_p_diff = h_p_2 - h_p_1
    C_p_product = C_p_1 * C_p_2

    # Shortest angular difference; forced to 0 when either color is achromatic
    delta_h_p = jnp.where(
        C_p_product == 0,
        0.0,
        jnp.where(
            jnp.abs(h_p_diff) <= 180,
            h_p_diff,
            jnp.where(h_p_diff > 180, h_p_diff - 360, h_p_diff + 360),
        ),
    )

    # Delta H'
    delta_H_p = 2 * jnp.sqrt(C_p_product) * jnp.sin(jnp.radians(delta_h_p / 2))

    # Mean L', C'
    L_bar_p = (L_1 + L_2) / 2
    C_bar_p = (C_p_1 + C_p_2) / 2

    # Mean h'
    h_p_sum = h_p_1 + h_p_2
    h_p_abs_diff = jnp.abs(h_p_1 - h_p_2)

    # When either chroma is zero the mean hue degenerates to the plain sum,
    # matching the CIEDE2000 implementation notes.
    h_bar_p = jnp.where(
        C_p_product == 0,
        h_p_sum,
        jnp.where(
            h_p_abs_diff <= 180,
            h_p_sum / 2,
            jnp.where(h_p_sum < 360, (h_p_sum + 360) / 2, (h_p_sum - 360) / 2),
        ),
    )

    # T factor
    T = (
        1
        - 0.17 * jnp.cos(jnp.radians(h_bar_p - 30))
        + 0.24 * jnp.cos(jnp.radians(2 * h_bar_p))
        + 0.32 * jnp.cos(jnp.radians(3 * h_bar_p + 6))
        - 0.20 * jnp.cos(jnp.radians(4 * h_bar_p - 63))
    )

    # Delta theta (rotation term, Gaussian centered at h' = 275 degrees)
    delta_theta = 30 * jnp.exp(-(((h_bar_p - 275) / 25) ** 2))

    # R_C (25^7 = 6103515625.0)
    C_bar_p_7 = C_bar_p**7
    R_C = 2 * jnp.sqrt(C_bar_p_7 / (C_bar_p_7 + 6103515625.0))

    # S_L, S_C, S_H weighting functions
    L_bar_p_minus_50_sq = (L_bar_p - 50) ** 2
    S_L = 1 + (0.015 * L_bar_p_minus_50_sq) / jnp.sqrt(20 + L_bar_p_minus_50_sq)
    S_C = 1 + 0.045 * C_bar_p
    S_H = 1 + 0.015 * C_bar_p * T

    # R_T
    R_T = -jnp.sin(jnp.radians(2 * delta_theta)) * R_C

    # Final Delta E (parametric factors fixed at unity)
    k_L, k_C, k_H = 1.0, 1.0, 1.0

    term_L = delta_L_p / (k_L * S_L)
    term_C = delta_C_p / (k_C * S_C)
    term_H = delta_H_p / (k_H * S_H)

    # NOTE(review): jnp.sqrt has an unbounded gradient at 0, so gradients may
    # be non-finite for exactly achromatic inputs (C' = 0) — confirm whether
    # training batches can contain neutral colors.
    return jnp.sqrt(term_L**2 + term_C**2 + term_H**2 + R_T * term_C * term_H)
231
+
232
+
233
def delta_E_loss(pred_xyY: Array, target_xyY: Array) -> Array:
    """
    Compute mean Delta-E loss between predicted and target xyY values.

    This is the primary loss function for training with perceptual accuracy.

    Parameters
    ----------
    pred_xyY : Array
        Predicted xyY values with shape (batch, 3)
    target_xyY : Array
        Target xyY values with shape (batch, 3)

    Returns
    -------
    Array
        Scalar mean Delta-E loss
    """
    # Map both sets of predictions into Lab and average the perceptual error.
    differences = delta_E_CIE2000(xyY_to_Lab(pred_xyY), xyY_to_Lab(target_xyY))
    return jnp.mean(differences)
254
+
255
+
256
# JIT-compiled versions for performance.
# NOTE: jax.jit compiles lazily on the first call for each new input
# shape/dtype combination.
xyY_to_XYZ_jit = jax.jit(xyY_to_XYZ)
XYZ_to_Lab_jit = jax.jit(XYZ_to_Lab)
xyY_to_Lab_jit = jax.jit(xyY_to_Lab)
delta_E_CIE2000_jit = jax.jit(delta_E_CIE2000)
delta_E_loss_jit = jax.jit(delta_E_loss)

# Gradient functions.
# jax.grad differentiates with respect to the first argument (pred_xyY).
grad_delta_E_loss = jax.grad(delta_E_loss)
265
+
266
+
267
def test_jax_delta_e() -> None:
    """
    Test the JAX Delta-E implementation against the colour library.

    Verifies that:

    - Lab values from the JAX conversion chain match the colour library's
      reference implementation under illuminant C.
    - The Delta-E CIE 2000 values agree within a small tolerance.
    - The gradient of the Delta-E loss is finite and non-zero.

    Raises
    ------
    AssertionError
        If any of the comparisons fail.
    """
    # Test xyY values
    xyY_1 = np.array([0.3127, 0.3290, 0.5])  # D65 white point, Y=0.5
    xyY_2 = np.array([0.35, 0.35, 0.5])  # Slightly shifted

    # Convert using JAX
    Lab_1_jax = xyY_to_Lab(jnp.array(xyY_1))
    Lab_2_jax = xyY_to_Lab(jnp.array(xyY_2))
    delta_E_jax = delta_E_CIE2000(Lab_1_jax, Lab_2_jax)

    # Convert using colour library (reference implementation)
    XYZ_1 = colour.xyY_to_XYZ(xyY_1)
    XYZ_2 = colour.xyY_to_XYZ(xyY_2)
    illuminant_C = colour.CCS_ILLUMINANTS["CIE 1931 2 Degree Standard Observer"]["C"]
    Lab_1_colour = colour.XYZ_to_Lab(XYZ_1, illuminant_C)
    Lab_2_colour = colour.XYZ_to_Lab(XYZ_2, illuminant_C)
    delta_E_colour = colour.delta_E(Lab_1_colour, Lab_2_colour, method="CIE 2000")

    # The previous version computed these values but asserted nothing, so the
    # test could never fail. Compare explicitly; tolerances are loose enough
    # for JAX's default float32 precision.
    np.testing.assert_allclose(
        np.asarray(Lab_1_jax), Lab_1_colour, rtol=1e-3, atol=5e-2
    )
    np.testing.assert_allclose(
        np.asarray(Lab_2_jax), Lab_2_colour, rtol=1e-3, atol=5e-2
    )
    np.testing.assert_allclose(
        float(delta_E_jax), float(delta_E_colour), rtol=1e-3, atol=5e-2
    )

    # Test gradient computation
    pred_xyY = jnp.array([[0.35, 0.35, 0.5]])
    target_xyY = jnp.array([[0.3127, 0.3290, 0.5]])

    grad_fn = jax.grad(lambda x: delta_E_loss(x, target_xyY))
    grads = np.asarray(grad_fn(pred_xyY))
    assert np.all(np.isfinite(grads)), "Delta-E loss gradient is not finite"
    assert np.any(grads != 0.0), "Delta-E loss gradient is identically zero"


if __name__ == "__main__":
    test_jax_delta_e()
learning_munsell/models/__init__.py ADDED
@@ -0,0 +1,47 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
"""Neural network models for Munsell color conversions."""

# Re-export the public model classes. This list is kept in sync with
# ``learning_munsell.models.networks.__all__``; the previous version omitted
# the ResNet-based models (ComponentResNet, MultiResNetToMunsell,
# MultiResNetErrorPredictorToMunsell) even though networks.py exports them.
from learning_munsell.models.networks import (
    # Building blocks
    ResidualBlock,
    # Component networks
    ComponentMLP,
    ComponentResNet,
    ComponentErrorPredictor,
    # Transformer building blocks
    FeatureTokenizer,
    TransformerBlock,
    # Composite models: xyY → Munsell
    MLPToMunsell,
    MultiHeadMLPToMunsell,
    MultiMLPToMunsell,
    MultiResNetToMunsell,
    TransformerToMunsell,
    # Error predictors: xyY → Munsell
    MultiHeadErrorPredictorToMunsell,
    MultiMLPErrorPredictorToMunsell,
    MultiResNetErrorPredictorToMunsell,
    # Composite models: Munsell → xyY
    MultiMLPToxyY,
    # Error predictors: Munsell → xyY
    MultiMLPErrorPredictorToxyY,
)

__all__ = [
    # Building blocks
    "ResidualBlock",
    # Component networks (single output)
    "ComponentMLP",
    "ComponentResNet",
    "ComponentErrorPredictor",
    # Transformer building blocks
    "FeatureTokenizer",
    "TransformerBlock",
    # Composite models: xyY → Munsell
    "MLPToMunsell",
    "MultiHeadMLPToMunsell",
    "MultiMLPToMunsell",
    "MultiResNetToMunsell",
    "TransformerToMunsell",
    # Error predictors: xyY → Munsell
    "MultiHeadErrorPredictorToMunsell",
    "MultiMLPErrorPredictorToMunsell",
    "MultiResNetErrorPredictorToMunsell",
    # Composite models: Munsell → xyY
    "MultiMLPToxyY",
    # Error predictors: Munsell → xyY
    "MultiMLPErrorPredictorToxyY",
]
learning_munsell/models/networks.py ADDED
@@ -0,0 +1,1294 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Reusable neural network building blocks.
3
+
4
+ Provides shared network architectures for training scripts,
5
+ including MLP components and error predictors.
6
+ """
7
+
8
+ from __future__ import annotations
9
+
10
+ import torch
11
+ from torch import nn, Tensor
12
+
13
+ __all__ = [
14
+ # Building blocks
15
+ "ResidualBlock",
16
+ # Component networks (single output)
17
+ "ComponentMLP",
18
+ "ComponentResNet",
19
+ "ComponentErrorPredictor",
20
+ # Transformer building blocks
21
+ "FeatureTokenizer",
22
+ "TransformerBlock",
23
+ # Composite models: xyY → Munsell
24
+ "MLPToMunsell",
25
+ "MultiHeadMLPToMunsell",
26
+ "MultiMLPToMunsell",
27
+ "MultiResNetToMunsell",
28
+ "TransformerToMunsell",
29
+ # Error predictors: xyY → Munsell
30
+ "MultiHeadErrorPredictorToMunsell",
31
+ "MultiMLPErrorPredictorToMunsell",
32
+ "MultiResNetErrorPredictorToMunsell",
33
+ # Composite models: Munsell → xyY
34
+ "MultiMLPToxyY",
35
+ # Error predictors: Munsell → xyY
36
+ "MultiMLPErrorPredictorToxyY",
37
+ ]
38
+
39
+
40
+ # =============================================================================
41
+ # Building Blocks
42
+ # =============================================================================
43
+
44
+
45
class ResidualBlock(nn.Module):
    """
    Residual block: two linear layers with a skip connection.

    Computes ``GELU(x + block(x))`` where ``block`` is
    Linear → GELU → BatchNorm → Linear → BatchNorm.

    Parameters
    ----------
    dim : int
        Dimension of input and output features.

    Attributes
    ----------
    block : nn.Sequential
        The non-identity path of the residual connection.
    activation : nn.GELU
        Activation applied after adding the skip connection.
    """

    def __init__(self, dim: int) -> None:
        """Initialize residual block."""
        super().__init__()
        self.block = nn.Sequential(
            nn.Linear(dim, dim),
            nn.GELU(),
            nn.BatchNorm1d(dim),
            nn.Linear(dim, dim),
            nn.BatchNorm1d(dim),
        )
        self.activation = nn.GELU()

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the residual transformation.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, dim).
        """
        residual = x
        transformed = self.block(x)
        return self.activation(residual + transformed)
91
+
92
+
93
+ # =============================================================================
94
+ # Component Networks (Single Output)
95
+ # =============================================================================
96
+
97
+
98
class ComponentMLP(nn.Module):
    """
    Independent MLP for a single Munsell component.

    Architecture: input_dim → 128 → 256 → 512 → 256 → 128 → 1
    (hidden widths scaled by ``width_multiplier``).

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 3 (for xyY).
    width_multiplier : float, optional
        Multiplier for hidden layer dimensions. Default is 1.0.
    dropout : float, optional
        Dropout probability between layers. Default is 0.0.

    Attributes
    ----------
    network : nn.Sequential
        Encoder-decoder feed-forward stack ending in a single output unit.

    Notes
    -----
    Each hidden stage is Linear → ReLU → BatchNorm. When ``dropout`` is
    positive, a Dropout layer follows every hidden stage except the last
    one before the output projection.
    """

    def __init__(
        self,
        input_dim: int = 3,
        width_multiplier: float = 1.0,
        dropout: float = 0.0,
    ) -> None:
        """Initialize the component-specific MLP."""
        super().__init__()

        # Encoder-decoder widths, scaled by the multiplier.
        h1 = int(128 * width_multiplier)
        h2 = int(256 * width_multiplier)
        h3 = int(512 * width_multiplier)
        widths = [input_dim, h1, h2, h3, h2, h1]

        layers: list[nn.Module] = []
        final_stage = len(widths) - 2
        for stage, (d_in, d_out) in enumerate(zip(widths[:-1], widths[1:])):
            layers += [nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)]
            # No dropout after the last hidden stage, matching the
            # encoder-decoder layout.
            if dropout > 0 and stage < final_stage:
                layers.append(nn.Dropout(dropout))

        # Output projection to a single component value.
        layers.append(nn.Linear(h1, 1))

        self.network = nn.Sequential(*layers)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the component-specific network.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 1) containing the predicted
            component value.
        """
        return self.network(x)
212
+
213
+
214
class ComponentResNet(nn.Module):
    """
    Independent ResNet for a single Munsell component with true skip connections.

    Architecture: input → projection → ResidualBlock × num_blocks → output.
    Each block computes ``activation(x + f(x))``, unlike the purely
    sequential ComponentMLP.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 3 (for xyY).
    hidden_dim : int, optional
        Hidden dimension for residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks. Default is 4.

    Attributes
    ----------
    input_proj : nn.Sequential
        Lifts the input to the residual trunk's width with a GELU.
    res_blocks : nn.ModuleList
        Stack of ResidualBlock modules with skip connections.
    output_proj : nn.Linear
        Collapses the hidden representation to a single output.
    """

    def __init__(
        self,
        input_dim: int = 3,
        hidden_dim: int = 256,
        num_blocks: int = 4,
    ) -> None:
        """Initialize the component-specific ResNet."""
        super().__init__()

        # Lift the input into the residual trunk's width.
        self.input_proj = nn.Sequential(
            nn.Linear(input_dim, hidden_dim),
            nn.GELU(),
        )

        # Residual trunk.
        self.res_blocks = nn.ModuleList(
            [ResidualBlock(hidden_dim) for _ in range(num_blocks)]
        )

        # Collapse to the single component prediction.
        self.output_proj = nn.Linear(hidden_dim, 1)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the ResNet with skip connections.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 1).
        """
        hidden = self.input_proj(x)
        for res_block in self.res_blocks:
            # Each block applies: activation(hidden + f(hidden))
            hidden = res_block(hidden)
        return self.output_proj(hidden)
283
+
284
+
285
class ComponentErrorPredictor(nn.Module):
    """
    Independent error predictor for a single Munsell component.

    A deep MLP that learns residual corrections for one Munsell component
    (hue, value, chroma, or code).

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 7 (xyY_norm + base_pred_norm).
    width_multiplier : float, optional
        Multiplier for hidden layer widths. Default is 1.0.
        Use 1.5 for chroma which requires more capacity.

    Attributes
    ----------
    network : nn.Sequential
        Encoder-decoder stack input → 128 → 256 → 512 → 256 → 128 → 1,
        each hidden stage being Linear → GELU → BatchNorm.

    Notes
    -----
    Default input is [xyY_norm (3) + base_pred_norm (4)] = 7 features.
    Output is a single scalar error correction for the component.
    """

    def __init__(
        self,
        input_dim: int = 7,
        width_multiplier: float = 1.0,
    ) -> None:
        """Initialize the error predictor."""
        super().__init__()

        # Encoder-decoder widths, scaled by the multiplier.
        h1 = int(128 * width_multiplier)
        h2 = int(256 * width_multiplier)
        h3 = int(512 * width_multiplier)
        widths = [input_dim, h1, h2, h3, h2, h1]

        stages: list[nn.Module] = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(d_in, d_out), nn.GELU(), nn.BatchNorm1d(d_out)]

        # Output projection to the scalar correction.
        stages.append(nn.Linear(h1, 1))

        self.network = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the error predictor.

        Parameters
        ----------
        x : Tensor
            Combined input of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Predicted error correction of shape (batch_size, 1).
        """
        return self.network(x)
362
+
363
+
364
+ # =============================================================================
365
+ # Transformer Building Blocks
366
+ # =============================================================================
367
+
368
+
369
class FeatureTokenizer(nn.Module):
    """
    Tokenize each input feature into a high-dimensional embedding.

    Each scalar input feature is mapped through its own linear layer to an
    embedding vector, analogous to word embeddings in NLP. A learnable CLS
    token is prepended and used downstream for regression.

    Parameters
    ----------
    num_features : int
        Number of input features to tokenize.
    embedding_dim : int
        Dimensionality of each token embedding.

    Attributes
    ----------
    feature_embeddings : nn.ModuleList
        One linear embedding layer per input feature.
    cls_token : nn.Parameter
        Learnable classification token prepended to feature tokens.
    """

    def __init__(self, num_features: int, embedding_dim: int) -> None:
        """Initialize the feature tokenizer."""
        super().__init__()
        # One dedicated embedding per scalar feature.
        self.feature_embeddings = nn.ModuleList(
            [nn.Linear(1, embedding_dim) for _ in range(num_features)]
        )
        # CLS token for regression.
        self.cls_token = nn.Parameter(torch.randn(1, 1, embedding_dim))

    def forward(self, x: Tensor) -> Tensor:
        """
        Transform input features into token embeddings.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, num_features).

        Returns
        -------
        Tensor
            Token embeddings of shape (batch_size, 1+num_features,
            embedding_dim); the first token is CLS, followed by one token
            per feature.
        """
        batch_size = x.size(0)

        # One (batch_size, 1, embedding_dim) token per scalar feature,
        # concatenated along the sequence axis.
        feature_tokens = torch.cat(
            [
                embed(x[:, i : i + 1]).unsqueeze(1)
                for i, embed in enumerate(self.feature_embeddings)
            ],
            dim=1,
        )

        # Broadcast the learnable CLS token over the batch and prepend it.
        cls_tokens = self.cls_token.expand(batch_size, -1, -1)
        return torch.cat([cls_tokens, feature_tokens], dim=1)
438
+
439
+
440
class TransformerBlock(nn.Module):
    """
    Standard transformer block: multi-head self-attention plus feedforward.

    Post-norm layout with residual connections: attention output is added to
    the input and layer-normalized, then the same pattern is applied to the
    feedforward sub-layer.

    Parameters
    ----------
    embedding_dim : int
        Dimension of token embeddings.
    num_heads : int
        Number of attention heads.
    ff_dim : int
        Hidden dimension of the feedforward network.
    dropout : float, optional
        Dropout probability, default is 0.1.

    Attributes
    ----------
    attention : nn.MultiheadAttention
        Multi-head self-attention mechanism (batch_first).
    norm1 : nn.LayerNorm
        Layer normalization after attention.
    feedforward : nn.Sequential
        Two-layer feedforward network with GELU activation.
    norm2 : nn.LayerNorm
        Layer normalization after the feedforward sub-layer.
    """

    def __init__(
        self, embedding_dim: int, num_heads: int, ff_dim: int, dropout: float = 0.1
    ) -> None:
        """Initialize the transformer block."""
        super().__init__()

        self.attention = nn.MultiheadAttention(
            embedding_dim, num_heads, dropout=dropout, batch_first=True
        )
        self.norm1 = nn.LayerNorm(embedding_dim)

        self.feedforward = nn.Sequential(
            nn.Linear(embedding_dim, ff_dim),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(ff_dim, embedding_dim),
            nn.Dropout(dropout),
        )
        self.norm2 = nn.LayerNorm(embedding_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Apply the transformer block to input tokens.

        Parameters
        ----------
        x : Tensor
            Input tokens of shape (batch_size, num_tokens, embedding_dim).

        Returns
        -------
        Tensor
            Transformed tokens of shape (batch_size, num_tokens, embedding_dim).
        """
        # Self-attention sub-layer with residual connection.
        attended, _ = self.attention(x, x, x)
        x = self.norm1(x + attended)

        # Feedforward sub-layer with residual connection.
        return self.norm2(x + self.feedforward(x))
511
+
512
+
513
+ # =============================================================================
514
+ # Composite Models: xyY → Munsell
515
+ # =============================================================================
516
+
517
+
518
class MLPToMunsell(nn.Module):
    """
    Large MLP for xyY to Munsell conversion.

    Architecture: 3 → 128 → 256 → 512 → 512 → 256 → 128 → 4, where each
    hidden stage is Linear → ReLU → BatchNorm.

    Attributes
    ----------
    network : nn.Sequential
        Feed-forward stack ending in a 4-unit output layer.
    """

    def __init__(self) -> None:
        """Initialize the MunsellMLP network."""
        super().__init__()

        widths = [3, 128, 256, 512, 512, 256, 128]
        stages: list[nn.Module] = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)]

        # Output projection to the four Munsell components.
        stages.append(nn.Linear(128, 4))

        self.network = nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the network.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 3) containing normalized xyY values.

        Returns
        -------
        Tensor
            Output tensor of shape (batch_size, 4) containing normalized Munsell
            specifications [hue, value, chroma, code].
        """
        return self.network(x)
572
+
573
+
574
class MultiHeadMLPToMunsell(nn.Module):
    """
    Multi-head MLP for xyY to Munsell conversion.

    A shared encoder feeds four specialized decoder heads, one per Munsell
    component (hue, value, chroma, code). The chroma head is wider to handle
    the more complex non-linear relationship between xyY and chroma.

    Attributes
    ----------
    encoder : nn.Sequential
        Shared encoder: 3 → 128 → 256 → 512 with ReLU and BatchNorm.
    hue_head : nn.Sequential
        Hue decoder: 512 → 256 → 128 → 1 (circular component).
    value_head : nn.Sequential
        Value decoder: 512 → 256 → 128 → 1 (linear component).
    chroma_head : nn.Sequential
        Chroma decoder: 512 → 384 → 256 → 128 → 1 (wider for complexity).
    code_head : nn.Sequential
        Code decoder: 512 → 256 → 128 → 1 (discrete component).
    """

    def __init__(self) -> None:
        """Initialize the multi-head MLP model."""
        super().__init__()

        # Shared encoder - learns general color space features.
        self.encoder = nn.Sequential(
            nn.Linear(3, 128),
            nn.ReLU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 256),
            nn.ReLU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.ReLU(),
            nn.BatchNorm1d(512),
        )

        # Component-specific decoder heads. Chroma gets an extra, wider
        # stage (512 → 384) because it is the hardest component to fit.
        self.hue_head = self._make_head([512, 256, 128])
        self.value_head = self._make_head([512, 256, 128])
        self.chroma_head = self._make_head([512, 384, 256, 128])
        self.code_head = self._make_head([512, 256, 128])

    @staticmethod
    def _make_head(widths: list[int]) -> nn.Sequential:
        """Build a decoder head: Linear→ReLU→BatchNorm stages plus a final Linear to 1."""
        stages: list[nn.Module] = []
        for d_in, d_out in zip(widths[:-1], widths[1:]):
            stages += [nn.Linear(d_in, d_out), nn.ReLU(), nn.BatchNorm1d(d_out)]
        stages.append(nn.Linear(widths[-1], 1))
        return nn.Sequential(*stages)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the multi-head network.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Concatenated Munsell predictions [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        # Shared feature extraction, then one prediction per head,
        # concatenated in [hue, value, chroma, code] order.
        features = self.encoder(x)
        heads = (self.hue_head, self.value_head, self.chroma_head, self.code_head)
        return torch.cat([head(features) for head in heads], dim=1)
691
+
692
+
693
class MultiMLPToMunsell(nn.Module):
    """
    Multi-MLP for xyY to Munsell conversion.

    Four fully independent ComponentMLP branches, one per Munsell component.
    The chroma branch can be widened to handle its more complex relationship
    with xyY.

    Parameters
    ----------
    chroma_width_multiplier : float, optional
        Width multiplier for the chroma branch. Default is 2.0.
    dropout : float, optional
        Dropout probability for all branches. Default is 0.1.

    Attributes
    ----------
    hue_branch : ComponentMLP
        MLP for hue component (1.0x width).
    value_branch : ComponentMLP
        MLP for value component (1.0x width).
    chroma_branch : ComponentMLP
        MLP for chroma component (configurable width).
    code_branch : ComponentMLP
        MLP for hue code component (1.0x width).
    """

    def __init__(
        self, chroma_width_multiplier: float = 2.0, dropout: float = 0.1
    ) -> None:
        """Initialize the multi-branch MLP model."""
        super().__init__()

        # Base-width branches; chroma gets its own multiplier.
        self.hue_branch = ComponentMLP(
            input_dim=3, width_multiplier=1.0, dropout=dropout
        )
        self.value_branch = ComponentMLP(
            input_dim=3, width_multiplier=1.0, dropout=dropout
        )
        self.chroma_branch = ComponentMLP(
            input_dim=3, width_multiplier=chroma_width_multiplier, dropout=dropout
        )
        self.code_branch = ComponentMLP(
            input_dim=3, width_multiplier=1.0, dropout=dropout
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through all 4 independent branches.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 3) containing normalized xyY values.

        Returns
        -------
        Tensor
            Concatenated predictions [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
758
+
759
+
760
class MultiResNetToMunsell(nn.Module):
    """
    Multi-ResNet for xyY to Munsell conversion with true skip connections.

    Four independent ComponentResNet branches, one per Munsell component,
    each built from residual blocks with skip connections.

    Parameters
    ----------
    hidden_dim : int, optional
        Hidden dimension for residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks per branch. Default is 4.
    chroma_hidden_dim : int, optional
        Hidden dimension for the chroma branch (typically larger). Default is 512.

    Attributes
    ----------
    hue_branch : ComponentResNet
        ResNet for hue component.
    value_branch : ComponentResNet
        ResNet for value component.
    chroma_branch : ComponentResNet
        ResNet for chroma component (larger hidden dim).
    code_branch : ComponentResNet
        ResNet for hue code component.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        chroma_hidden_dim: int = 512,
    ) -> None:
        """Initialize the multi-branch ResNet model."""
        super().__init__()

        # Base-width trunks; chroma gets its own (wider) hidden dimension.
        self.hue_branch = ComponentResNet(
            input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
        )
        self.value_branch = ComponentResNet(
            input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
        )
        self.chroma_branch = ComponentResNet(
            input_dim=3, hidden_dim=chroma_hidden_dim, num_blocks=num_blocks
        )
        self.code_branch = ComponentResNet(
            input_dim=3, hidden_dim=hidden_dim, num_blocks=num_blocks
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through all 4 independent ResNet branches.

        Parameters
        ----------
        x : Tensor
            Input tensor of shape (batch_size, 3) containing normalized xyY values.

        Returns
        -------
        Tensor
            Concatenated predictions [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        return torch.cat([branch(x) for branch in branches], dim=1)
830
+
831
+
832
class TransformerToMunsell(nn.Module):
    """
    Transformer for xyY to Munsell conversion.

    Input features are embedded by a feature tokenizer (which also adds a
    CLS token), processed by a stack of self-attention transformer blocks,
    and read out by one specialized head per Munsell component.

    Parameters
    ----------
    num_features : int, optional
        Number of input features (default is 3 for xyY).
    embedding_dim : int, optional
        Dimension of token embeddings (default is 256).
    num_blocks : int, optional
        Number of transformer blocks (default is 6).
    num_heads : int, optional
        Number of attention heads (default is 8).
    ff_dim : int, optional
        Feedforward network hidden dimension (default is 1024).
    dropout : float, optional
        Dropout probability (default is 0.1).

    Attributes
    ----------
    tokenizer : FeatureTokenizer
        Converts input features to token embeddings with CLS token.
    transformer_blocks : nn.ModuleList
        Stack of transformer blocks with self-attention.
    final_norm : nn.LayerNorm
        Final layer normalization before output heads.
    hue_head : nn.Sequential
        Output head for hue prediction.
    value_head : nn.Sequential
        Output head for value prediction.
    chroma_head : nn.Sequential
        Deeper output head for chroma prediction.
    code_head : nn.Sequential
        Output head for hue code prediction.

    Notes
    -----
    Architecture: 3 xyY features → 3 tokens + 1 CLS token → transformer blocks
    with self-attention → multi-head output with specialized component heads.
    The chroma head has additional depth due to prediction difficulty.
    """

    def __init__(
        self,
        num_features: int = 3,
        embedding_dim: int = 256,
        num_blocks: int = 6,
        num_heads: int = 8,
        ff_dim: int = 1024,
        dropout: float = 0.1,
    ) -> None:
        """Initialize the transformer model."""
        super().__init__()

        self.tokenizer = FeatureTokenizer(num_features, embedding_dim)

        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
            for _ in range(num_blocks)
        )

        self.final_norm = nn.LayerNorm(embedding_dim)

        def shallow_head() -> nn.Sequential:
            # Two-layer head shared in shape by hue, value and code.
            return nn.Sequential(
                nn.Linear(embedding_dim, 128),
                nn.GELU(),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            )

        # Multi-head output - separate heads for each Munsell component
        self.hue_head = shallow_head()
        self.value_head = shallow_head()
        # Chroma is harder to predict, so its head gets an extra layer.
        self.chroma_head = nn.Sequential(
            nn.Linear(embedding_dim, 256),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.Linear(128, 1),
        )
        self.code_head = shallow_head()

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through the transformer.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Predicted Munsell specification [hue, value, chroma, code]
            of shape (batch_size, 4).

        Notes
        -----
        The CLS token representation is used for the final prediction through
        separate task-specific heads for each Munsell component.
        """
        tokens = self.tokenizer(x)

        for block in self.transformer_blocks:
            tokens = block(tokens)

        # Normalize, then keep only the CLS token (position 0) for readout.
        cls_token = self.final_norm(tokens)[:, 0, :]

        heads = (self.hue_head, self.value_head, self.chroma_head, self.code_head)
        return torch.cat([head(cls_token) for head in heads], dim=1)
964
+
965
+
966
+ # =============================================================================
967
+ # Error Predictors: xyY → Munsell
968
+ # =============================================================================
969
+
970
+
971
class MultiHeadErrorPredictorToMunsell(nn.Module):
    """
    Multi-Head error predictor for xyY to Munsell conversion.

    One ComponentErrorPredictor per Munsell component; each branch sees the
    same combined input and predicts the correction for its own component.
    The chroma branch is wider (1.5x by default) because chroma error
    patterns are harder to model.

    Parameters
    ----------
    input_dim : int, optional
        Input feature dimension. Default is 7.
    chroma_width : float, optional
        Width multiplier for chroma branch. Default is 1.5.

    Attributes
    ----------
    hue_branch : ComponentErrorPredictor
        Error predictor for hue component (1.0x width).
    value_branch : ComponentErrorPredictor
        Error predictor for value component (1.0x width).
    chroma_branch : ComponentErrorPredictor
        Error predictor for chroma component (1.5x width by default).
    code_branch : ComponentErrorPredictor
        Error predictor for hue code component (1.0x width).
    """

    def __init__(
        self,
        input_dim: int = 7,
        chroma_width: float = 1.5,
    ) -> None:
        """Initialize the multi-head error predictor."""
        super().__init__()

        def make_branch(width: float) -> "ComponentErrorPredictor":
            # Independent predictor per component; only width differs.
            return ComponentErrorPredictor(
                input_dim=input_dim, width_multiplier=width
            )

        self.hue_branch = make_branch(1.0)
        self.value_branch = make_branch(1.0)
        self.chroma_branch = make_branch(chroma_width)
        self.code_branch = make_branch(1.0)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through all error predictor branches.

        Parameters
        ----------
        x : Tensor
            Combined input of shape (batch_size, input_dim).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        # Each branch processes the same combined input independently;
        # outputs are stacked as [Hue_error, Value_error, Chroma_error, Code_error].
        corrections = [
            branch(x)
            for branch in (
                self.hue_branch,
                self.value_branch,
                self.chroma_branch,
                self.code_branch,
            )
        ]
        return torch.cat(corrections, dim=1)
1043
+
1044
+
1045
class MultiMLPErrorPredictorToMunsell(nn.Module):
    """
    Multi-MLP error predictor for xyY to Munsell conversion.

    Uses 4 independent ComponentErrorPredictor branches, one for each
    Munsell component error.

    Parameters
    ----------
    chroma_width : float, optional
        Width multiplier for chroma branch. Default is 1.5.
    dropout : float, optional
        Dropout probability forwarded to every branch.  When ``None``
        (the default), each branch keeps its own built-in default, which
        preserves the behavior of the previous two-argument signature.
        This parameter is required by the hyperparameter-search scripts,
        which construct this class with a ``dropout=`` keyword.

    Attributes
    ----------
    hue_branch : ComponentErrorPredictor
        Error predictor for hue component (1.0x width).
    value_branch : ComponentErrorPredictor
        Error predictor for value component (1.0x width).
    chroma_branch : ComponentErrorPredictor
        Error predictor for chroma component (configurable width).
    code_branch : ComponentErrorPredictor
        Error predictor for hue code component (1.0x width).
    """

    def __init__(
        self, chroma_width: float = 1.5, dropout: "float | None" = None
    ) -> None:
        """Initialize the multi-head error predictor."""
        super().__init__()

        # Forward ``dropout`` only when it is explicitly provided so that the
        # default construction is identical to the previous signature.
        # NOTE(review): assumes ComponentErrorPredictor accepts a ``dropout``
        # keyword (as ComponentMLP does) -- confirm against its definition.
        extra = {} if dropout is None else {"dropout": dropout}

        self.hue_branch = ComponentErrorPredictor(
            input_dim=7, width_multiplier=1.0, **extra
        )
        self.value_branch = ComponentErrorPredictor(
            input_dim=7, width_multiplier=1.0, **extra
        )
        self.chroma_branch = ComponentErrorPredictor(
            input_dim=7, width_multiplier=chroma_width, **extra
        )
        self.code_branch = ComponentErrorPredictor(
            input_dim=7, width_multiplier=1.0, **extra
        )

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through all error predictor branches.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        hue_error = self.hue_branch(x)
        value_error = self.value_branch(x)
        chroma_error = self.chroma_branch(x)
        code_error = self.code_branch(x)
        return torch.cat([hue_error, value_error, chroma_error, code_error], dim=1)
1100
+
1101
+
1102
class MultiResNetErrorPredictorToMunsell(nn.Module):
    """
    Multi-ResNet error predictor for xyY to Munsell conversion.

    Four independent ComponentResNet branches with true skip connections,
    one per Munsell component error.

    Parameters
    ----------
    hidden_dim : int, optional
        Hidden dimension for residual blocks. Default is 256.
    num_blocks : int, optional
        Number of residual blocks per branch. Default is 4.
    chroma_hidden_dim : int, optional
        Hidden dimension for chroma branch. Default is 384.

    Attributes
    ----------
    hue_branch : ComponentResNet
        ResNet error predictor for hue component.
    value_branch : ComponentResNet
        ResNet error predictor for value component.
    chroma_branch : ComponentResNet
        ResNet error predictor for chroma component.
    code_branch : ComponentResNet
        ResNet error predictor for code component.
    """

    def __init__(
        self,
        hidden_dim: int = 256,
        num_blocks: int = 4,
        chroma_hidden_dim: int = 384,
    ) -> None:
        """Initialize the multi-ResNet error predictor."""
        super().__init__()

        def make_branch(width: int) -> "ComponentResNet":
            # Input: xyY (3) + base prediction (4) = 7 features.
            return ComponentResNet(
                input_dim=7, hidden_dim=width, num_blocks=num_blocks
            )

        self.hue_branch = make_branch(hidden_dim)
        self.value_branch = make_branch(hidden_dim)
        self.chroma_branch = make_branch(chroma_hidden_dim)
        self.code_branch = make_branch(hidden_dim)

    def forward(self, x: Tensor) -> Tensor:
        """
        Forward pass through all error predictor branches.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        corrections = [
            branch(x)
            for branch in (
                self.hue_branch,
                self.value_branch,
                self.chroma_branch,
                self.code_branch,
            )
        ]
        return torch.cat(corrections, dim=1)
1173
+
1174
+
1175
+ # =============================================================================
1176
+ # Composite Models: Munsell → xyY
1177
+ # =============================================================================
1178
+
1179
+
1180
class MultiMLPToxyY(nn.Module):
    """
    Multi-MLP for Munsell to xyY conversion.

    Three independent ComponentMLP branches, one per xyY component.

    Parameters
    ----------
    width_multiplier : float, optional
        Width multiplier for x and y branches. Default is 1.0.
    y_width_multiplier : float, optional
        Width multiplier for Y (luminance) branch. Default is 1.25.

    Attributes
    ----------
    x_branch : ComponentMLP
        MLP for x chromaticity component.
    y_branch : ComponentMLP
        MLP for y chromaticity component.
    Y_branch : ComponentMLP
        MLP for Y luminance component.
    """

    def __init__(
        self, width_multiplier: float = 1.0, y_width_multiplier: float = 1.25
    ) -> None:
        """Initialize the multi-MLP model."""
        super().__init__()

        # Input is the 4-component Munsell specification; the luminance
        # branch is slightly wider than the chromaticity branches.
        self.x_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier)
        self.y_branch = ComponentMLP(input_dim=4, width_multiplier=width_multiplier)
        self.Y_branch = ComponentMLP(input_dim=4, width_multiplier=y_width_multiplier)

    def forward(self, munsell: Tensor) -> Tensor:
        """
        Forward pass through all branches.

        Parameters
        ----------
        munsell : Tensor
            Normalized Munsell specification [hue, value, chroma, code]
            of shape (batch_size, 4).

        Returns
        -------
        Tensor
            Predicted xyY values [x, y, Y] of shape (batch_size, 3).
        """
        parts = (
            self.x_branch(munsell),
            self.y_branch(munsell),
            self.Y_branch(munsell),
        )
        return torch.cat(parts, dim=1)
1234
+
1235
+
1236
+ # =============================================================================
1237
+ # Error Predictors: Munsell → xyY
1238
+ # =============================================================================
1239
+
1240
+
1241
class MultiMLPErrorPredictorToxyY(nn.Module):
    """
    Multi-MLP error predictor for Munsell to xyY conversion.

    Three independent ComponentErrorPredictor branches, one per xyY
    component error.

    Parameters
    ----------
    width_multiplier : float, optional
        Width multiplier for all branches. Default is 1.0.

    Attributes
    ----------
    x_branch : ComponentErrorPredictor
        Error predictor for x chromaticity component.
    y_branch : ComponentErrorPredictor
        Error predictor for y chromaticity component.
    Y_branch : ComponentErrorPredictor
        Error predictor for Y luminance component.
    """

    def __init__(self, width_multiplier: float = 1.0) -> None:
        """Initialize the multi-head error predictor."""
        super().__init__()

        def make_branch() -> "ComponentErrorPredictor":
            # Input: Munsell spec (4) + base xyY prediction (3) = 7 features.
            return ComponentErrorPredictor(
                input_dim=7, width_multiplier=width_multiplier
            )

        self.x_branch = make_branch()
        self.y_branch = make_branch()
        self.Y_branch = make_branch()

    def forward(self, combined_input: Tensor) -> Tensor:
        """
        Forward pass through all error predictor branches.

        Parameters
        ----------
        combined_input : Tensor
            Combined input [munsell_norm, base_pred] of shape (batch_size, 7).

        Returns
        -------
        Tensor
            Concatenated error corrections [x, y, Y] of shape (batch_size, 3).
        """
        corrections = (
            self.x_branch(combined_input),
            self.y_branch(combined_input),
            self.Y_branch(combined_input),
        )
        return torch.cat(corrections, dim=1)
learning_munsell/training/from_xyY/__init__.py ADDED
@@ -0,0 +1 @@
 
 
1
+ """Training scripts for xyY to Munsell conversion."""
learning_munsell/training/from_xyY/hyperparameter_search_error_predictor.py ADDED
@@ -0,0 +1,503 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Error Predictor using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Chroma width multiplier
8
+ - Loss function weights (MSE, MAE, log penalty, Huber)
9
+ - Huber delta
10
+ - Dropout
11
+
12
+ Objective: Minimize validation loss
13
+ """
14
+
15
+ import logging
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+
19
+ import mlflow
20
+ import numpy as np
21
+ import onnxruntime as ort
22
+ import optuna
23
+ import torch
24
+ from numpy.typing import NDArray
25
+ from optuna.trial import Trial
26
+ from torch import nn, optim
27
+ from torch.utils.data import DataLoader, TensorDataset
28
+
29
+ from learning_munsell import PROJECT_ROOT
30
+ from learning_munsell.models.networks import (
31
+ ComponentErrorPredictor,
32
+ MultiMLPErrorPredictorToMunsell,
33
+ )
34
+ from learning_munsell.utilities.common import setup_mlflow_experiment
35
+ from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
36
+
37
+ LOGGER = logging.getLogger(__name__)
38
+
39
+
40
def precision_focused_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    mse_weight: float = 1.0,
    mae_weight: float = 0.5,
    log_weight: float = 0.3,
    huber_weight: float = 0.5,
    huber_delta: float = 0.01,
) -> torch.Tensor:
    """
    Precision-focused loss function with configurable weights.

    Blends four terms, each encouraging accurate error prediction:
    - MSE: standard mean squared error
    - MAE: mean absolute error for robustness
    - Log penalty: emphasizes small residuals via log1p of the scaled error
    - Huber loss: robust to outliers with adjustable delta

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values, shape (batch_size, n_components).
    target : torch.Tensor
        Target values, shape (batch_size, n_components).
    mse_weight : float, optional
        Weight for MSE component. Default is 1.0.
    mae_weight : float, optional
        Weight for MAE component. Default is 0.5.
    log_weight : float, optional
        Weight for logarithmic penalty component. Default is 0.3.
    huber_weight : float, optional
        Weight for Huber loss component. Default is 0.5.
    huber_delta : float, optional
        Delta parameter for Huber loss transition point. Default is 0.01.

    Returns
    -------
    torch.Tensor
        Weighted combination of loss components, scalar tensor.
    """
    # Compute the residual once and reuse it for every term.
    residual = pred - target
    abs_residual = residual.abs()

    mse = residual.pow(2).mean()
    mae = abs_residual.mean()
    log_penalty = torch.log1p(abs_residual * 1000.0).mean()

    # Huber via the algebraic identity: with q = min(|r|, delta),
    # huber = 0.5 * q**2 + delta * (|r| - q), which matches the usual
    # piecewise definition (quadratic inside delta, linear outside).
    quadratic = torch.clamp(abs_residual, max=huber_delta)
    huber_loss = (0.5 * quadratic**2 + huber_delta * (abs_residual - quadratic)).mean()

    return (
        mse_weight * mse
        + mae_weight * mae
        + log_weight * log_penalty
        + huber_weight * huber_loss
    )
99
+
100
+
101
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base ONNX model and its normalization parameters.

    Parameters
    ----------
    model_path : Path
        Path to the base model ONNX file.
    params_path : Path
        Path to the normalization parameters NPZ file.

    Returns
    -------
    ort.InferenceSession
        ONNX Runtime inference session for the base model.
    dict
        Input normalization parameters (x_range, y_range, Y_range).
    dict
        Output normalization parameters (hue_range, value_range, chroma_range, code_range).
    """
    inference_session = ort.InferenceSession(str(model_path))
    # allow_pickle is required because the parameter dicts were stored as
    # 0-d object arrays; .item() unwraps each back into a plain dict.
    saved = np.load(params_path, allow_pickle=True)
    input_params = saved["input_params"].item()
    output_params = saved["output_params"].item()
    return inference_session, input_params, output_params
126
+
127
+
128
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    loss_params: dict[str, float],
) -> float:
    """
    Train the model for one epoch.

    Parameters
    ----------
    model : nn.Module
        Error predictor model to train.
    dataloader : DataLoader
        DataLoader providing training batches.
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on (CPU, CUDA, or MPS).
    loss_params : dict of str to float
        Parameters for precision_focused_loss function.

    Returns
    -------
    float
        Average training loss over the epoch.
    """
    model.train()
    running_loss = 0.0

    for features, targets in dataloader:
        features = features.to(device)
        targets = targets.to(device)

        batch_loss = precision_focused_loss(model(features), targets, **loss_params)

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    # Mean over batches (not samples), mirroring the validation metric.
    return running_loss / len(dataloader)
172
+
173
+
174
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    loss_params: dict[str, float],
) -> float:
    """
    Validate the model on the validation set.

    Parameters
    ----------
    model : nn.Module
        Error predictor model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches.
    device : torch.device
        Device to run validation on (CPU, CUDA, or MPS).
    loss_params : dict of str to float
        Parameters for precision_focused_loss function.

    Returns
    -------
    float
        Average validation loss.
    """
    model.eval()
    batch_losses = []

    # No gradients needed during evaluation.
    with torch.no_grad():
        for features, targets in dataloader:
            outputs = model(features.to(device))
            loss = precision_focused_loss(outputs, targets.to(device), **loss_params)
            batch_losses.append(loss.item())

    return sum(batch_losses) / len(dataloader)
212
+
213
+
214
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (5e-4 to 1e-3, log scale)
    - Batch size (512 or 1024)
    - Chroma branch width multiplier (1.0 to 1.5)
    - Dropout rate (0.1 to 0.2)
    - Loss function weights (MSE, Huber)
    - Huber delta parameter (0.01 to 0.05)

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    FileNotFoundError
        If base model or training data files are not found.
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Hyperparameters to optimize - constrained based on Trial 0 insights
    lr = trial.suggest_float("lr", 5e-4, 1e-3, log=True)  # Higher LR worked well
    batch_size = trial.suggest_categorical(
        "batch_size", [512, 1024]
    )  # Smaller batches better
    chroma_width = trial.suggest_float(
        "chroma_width", 1.0, 1.5, step=0.25
    )  # Smaller worked
    dropout = trial.suggest_float("dropout", 0.1, 0.2, step=0.05)

    # Simplified loss - just MSE + optional small Huber (no log penalty!)
    mse_weight = trial.suggest_float("mse_weight", 1.0, 2.0, step=0.25)
    huber_weight = trial.suggest_float("huber_weight", 0.0, 0.5, step=0.25)
    huber_delta = trial.suggest_float("huber_delta", 0.01, 0.05, step=0.01)

    # MAE and log-penalty terms are deliberately disabled for this search.
    loss_params = {
        "mse_weight": mse_weight,
        "mae_weight": 0.0,  # Fixed at 0
        "log_weight": 0.0,  # Fixed at 0 (was causing scale issues)
        "huber_weight": huber_weight,
        "huber_delta": huber_delta,
    }

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" chroma_width: %.2f", chroma_width)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info(" mse_weight: %.2f", mse_weight)
    LOGGER.info(" huber_weight: %.2f", huber_weight)
    LOGGER.info(" huber_delta: %.3f", huber_delta)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load base model and data
    model_dir = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    base_model_path = model_dir / "multi_mlp.onnx"
    params_path = model_dir / "multi_mlp_normalization_params.npz"
    cache_file = data_dir / "training_data.npz"

    if not base_model_path.exists():
        msg = f"Base model not found: {base_model_path}"
        raise FileNotFoundError(msg)

    base_session, input_params, output_params = load_base_model(
        base_model_path, params_path
    )

    # Load data
    # NOTE(review): cache_file existence is not checked before np.load,
    # unlike base_model_path above -- a missing cache raises a less
    # informative error. Consider adding the same guard.
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize and generate base predictions
    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors
    # The error predictor learns the residual between the ground truth and
    # the frozen base model's prediction, both in normalized space.
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Combined input
    # Features are the normalized xyY input concatenated with the base
    # model's normalized prediction (3 + 4 = 7 features).
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    # NOTE(review): ensure MultiMLPErrorPredictorToMunsell.__init__ accepts a
    # ``dropout`` keyword; the definition in networks.py only declares
    # ``chroma_width``, in which case this call raises TypeError.
    model = MultiMLPErrorPredictorToMunsell(chroma_width=chroma_width, dropout=dropout).to(
        device
    )

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info(" Total parameters: %s", f"{total_params:,}")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_error_predictor_trial_{trial.number}"
    )

    # Training loop
    num_epochs = 100
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "chroma_width": chroma_width,
                "dropout": dropout,
                "mse_weight": mse_weight,
                "huber_weight": huber_weight,
                "huber_delta": huber_delta,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, optimizer, device, loss_params
            )
            val_loss = validate(model, val_loader, device, loss_params)

            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                },
                step=epoch,
            )

            # Periodic progress logging only, to keep output manageable.
            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                )

            # Early stopping on stalled validation loss.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report to Optuna for median pruning.  Note this runs after the
            # early-stopping check, so the epoch that triggers early stop is
            # never reported.
            trial.report(val_loss, epoch)

            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
421
+
422
+
423
def main() -> None:
    """
    Run hyperparameter search for Multi-MLP Error Predictor.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 15 total trials
    - MLflow logging for each trial
    - Result visualization and saving

    The search aims to find optimal hyperparameters for predicting errors
    in a base Munsell prediction model, which can then be used to improve
    predictions by correcting systematic biases.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Error Predictor Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    study = optuna.create_study(
        direction="minimize",
        study_name="multi_mlp_error_predictor_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    n_trials = 15

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save results
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    # parents=True so the run also works when the "results" root itself is
    # missing (fresh checkout).
    results_dir.mkdir(parents=True, exist_ok=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"error_predictor_hparam_search_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Error Predictor Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for trial in study.trials:
            f.write(f"\nTrial {trial.number}:\n")
            # Pruned/failed trials carry no value.  The previous code embedded
            # a conditional inside the f-string format spec
            # ("{trial.value:.6f if trial.value else 'Pruned'}"), which is an
            # invalid format specification and raised ValueError for every
            # completed trial; it also misclassified a legitimate 0.0 loss as
            # pruned.  Format the value explicitly instead.
            value_text = (
                f"{trial.value:.6f}" if trial.value is not None else "Pruned"
            )
            f.write(f" Value: {value_text}\n")
            f.write(" Params:\n")
            for key, value in trial.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)
498
+
499
+
500
if __name__ == "__main__":
    # Script entry point: configure root logging for bare-message output.
    # ``force=True`` replaces any handlers already installed by imported
    # libraries so the script's log format wins.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_head.py ADDED
@@ -0,0 +1,541 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Head model (xyY to Munsell) using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Encoder width multiplier (shared encoder capacity)
8
+ - Head width multiplier (component-specific head capacity)
9
+ - Chroma head width (specialized for chroma prediction)
10
+ - Dropout
11
+ - Weight decay
12
+
13
+ Objective: Minimize validation loss
14
+ """
15
+
16
+ from __future__ import annotations
17
+
18
+ import logging
19
+ from datetime import datetime
20
+
21
+ import matplotlib.pyplot as plt
22
+ import mlflow
23
+ import numpy as np
24
+ import optuna
25
+ import torch
26
+ from optuna.trial import Trial
27
+ from torch import nn, optim
28
+ from torch.utils.data import DataLoader, TensorDataset
29
+
30
+ from learning_munsell import PROJECT_ROOT
31
+ from learning_munsell.utilities.common import setup_mlflow_experiment
32
+ from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
33
+ from learning_munsell.utilities.losses import weighted_mse_loss
34
+ from learning_munsell.utilities.training import train_epoch, validate
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+
39
class MultiHeadParametric(nn.Module):
    """
    Parametric Multi-Head model for hyperparameter search (xyY to Munsell).

    This model uses a shared encoder to extract general color space features
    from xyY inputs, followed by component-specific heads for predicting
    each Munsell component independently.

    Architecture:
    - Shared encoder: 3 → h1 → h2 → h3 (scaled by encoder_width)
    - hue, value, code heads: h3 → h2' → h1' → 1 (scaled by head_width)
    - chroma head: h3 → h3'' → h2'' → h1'' → 1 (scaled by chroma_head_width)

    Parameters
    ----------
    encoder_width : float, optional
        Width multiplier for shared encoder layers. Default is 1.0.
        Base dimensions: h1=128, h2=256, h3=512.
    head_width : float, optional
        Width multiplier for hue, value, and code heads. Default is 1.0.
        Base dimensions: h1=128, h2=256.
    chroma_head_width : float, optional
        Width multiplier for chroma head (typically wider). Default is 1.0.
        Base dimensions: h1=128, h2=256, h3=384.
    dropout : float, optional
        Dropout rate applied after hidden layers. Default is 0.0.
        When 0.0, no Dropout modules are inserted at all.
    """

    def __init__(
        self,
        encoder_width: float = 1.0,
        head_width: float = 1.0,
        chroma_head_width: float = 1.0,
        dropout: float = 0.0,
    ) -> None:
        super().__init__()

        # Encoder dimensions (shared)
        e_h1 = int(128 * encoder_width)
        e_h2 = int(256 * encoder_width)
        e_h3 = int(512 * encoder_width)

        # Head dimensions (component-specific)
        h_h1 = int(128 * head_width)
        h_h2 = int(256 * head_width)

        # Chroma head dimensions (specialized: wider and one layer deeper)
        c_h1 = int(128 * chroma_head_width)
        c_h2 = int(256 * chroma_head_width)
        c_h3 = int(384 * chroma_head_width)

        # Shared encoder - learns general color space features.
        # Layer sequence is identical to building the stack by hand:
        # [Linear, ReLU, BatchNorm1d(, Dropout)] per hidden width.
        self.encoder = nn.Sequential(*self._hidden_stack(3, [e_h1, e_h2, e_h3], dropout))

        # Component-specific heads (hue, value, code): e_h3 → h_h2 → h_h1 → 1
        self.hue_head = self._make_head(e_h3, [h_h2, h_h1], dropout)
        self.value_head = self._make_head(e_h3, [h_h2, h_h1], dropout)
        self.code_head = self._make_head(e_h3, [h_h2, h_h1], dropout)

        # Chroma head - wider/deeper for the harder task:
        # e_h3 → c_h3 → c_h2 → c_h1 → 1
        self.chroma_head = self._make_head(e_h3, [c_h3, c_h2, c_h1], dropout)

    @staticmethod
    def _hidden_stack(in_dim: int, dims: list[int], dropout: float) -> list[nn.Module]:
        """
        Build a flat list of [Linear, ReLU, BatchNorm1d(, Dropout)] blocks.

        Parameters
        ----------
        in_dim : int
            Input feature dimension of the first Linear layer.
        dims : list[int]
            Output width of each successive hidden block.
        dropout : float
            Dropout rate; a Dropout layer is appended after each block
            only when strictly positive.

        Returns
        -------
        list[nn.Module]
            Layers in order, ready for ``nn.Sequential``.
        """
        layers: list[nn.Module] = []
        previous = in_dim
        for width in dims:
            layers.extend([nn.Linear(previous, width), nn.ReLU(), nn.BatchNorm1d(width)])
            if dropout > 0:
                layers.append(nn.Dropout(dropout))
            previous = width
        return layers

    @classmethod
    def _make_head(cls, in_dim: int, dims: list[int], dropout: float) -> nn.Sequential:
        """
        Build a prediction head: hidden stack over *dims*, then a scalar output.

        Parameters
        ----------
        in_dim : int
            Input feature dimension (the encoder's output width).
        dims : list[int]
            Hidden widths of the head, in order.
        dropout : float
            Dropout rate forwarded to the hidden stack.

        Returns
        -------
        nn.Sequential
            Head mapping (batch, in_dim) → (batch, 1).
        """
        layers = cls._hidden_stack(in_dim, dims, dropout)
        layers.append(nn.Linear(dims[-1], 1))
        return nn.Sequential(*layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through shared encoder and component-specific heads.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 3) containing normalized
            xyY values.

        Returns
        -------
        torch.Tensor
            Predicted Munsell components, shape (batch_size, 4).
            Output order: [hue, value, chroma, code].
        """
        # Shared feature extraction
        features = self.encoder(x)

        # Component-specific predictions
        hue = self.hue_head(features)
        value = self.value_head(features)
        chroma = self.chroma_head(features)
        code = self.code_head(features)

        # Concatenate: [hue, value, chroma, code]
        return torch.cat([hue, value, chroma, code], dim=1)
217
+
218
+
219
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (256, 512, or 1024)
    - Encoder width multiplier (0.75 to 1.5)
    - Head width multiplier (0.75 to 1.5)
    - Chroma head width multiplier (1.0 to 1.75)
    - Dropout rate (0.0 to 0.2)
    - Weight decay (1e-5 to 1e-3, log scale)

    Each trial trains a ``MultiHeadParametric`` model for up to 100 epochs
    with early stopping (patience 15) and reports intermediate validation
    losses to Optuna for pruning. All metrics are mirrored to MLflow.

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [256, 512, 1024])
    encoder_width = trial.suggest_float("encoder_width", 0.75, 1.5, step=0.25)
    head_width = trial.suggest_float("head_width", 0.75, 1.5, step=0.25)
    chroma_head_width = trial.suggest_float("chroma_head_width", 1.0, 1.75, step=0.25)
    dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)
    weight_decay = trial.suggest_float("weight_decay", 1e-5, 1e-3, log=True)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" encoder_width: %.2f", encoder_width)
    LOGGER.info(" head_width: %.2f", head_width)
    LOGGER.info(" chroma_head_width: %.2f", chroma_head_width)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info(" weight_decay: %.6f", weight_decay)

    # Set device: MPS (Apple Silicon) when available, otherwise CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info(" device: %s", device)

    # Load data (re-read from the cached NPZ on every trial)
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to tensors
    X_train_t = torch.from_numpy(X_train).float()
    y_train_t = torch.from_numpy(y_train_norm).float()
    X_val_t = torch.from_numpy(X_val).float()
    y_val_t = torch.from_numpy(y_val_norm).float()

    train_loader = DataLoader(
        TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
    )

    LOGGER.info(
        " Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
    )

    # Initialize model
    model = MultiHeadParametric(
        encoder_width=encoder_width,
        head_width=head_width,
        chroma_head_width=chroma_head_width,
        dropout=dropout,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info(" Total parameters: %s", f"{total_params:,}")

    # Training setup; T_max matches num_epochs below so the cosine schedule
    # spans exactly one full search budget.
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=100)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_head_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 100  # Reduced for hyperparameter search
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "encoder_width": encoder_width,
                "head_width": head_width,
                "chroma_head_width": chroma_head_width,
                "dropout": dropout,
                "weight_decay": weight_decay,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, optimizer, weighted_mse_loss, device
            )
            val_loss = validate(model, val_loader, weighted_mse_loss, device)
            scheduler.step()

            # Per-component MAE on the whole validation set in one pass.
            # NOTE(review): the model's train/eval mode here depends on what
            # validate() leaves behind — confirm BatchNorm/Dropout are in
            # eval mode for this pass, otherwise the MAE metrics are noisy.
            with torch.no_grad():
                pred_val = model(X_val_t.to(device))
                mae = torch.mean(torch.abs(pred_val - y_val_t.to(device)), dim=0).cpu()

            # Log to MLflow
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "mae_hue": mae[0].item(),
                    "mae_value": mae[1].item(),
                    "mae_chroma": mae[2].item(),
                    "mae_code": mae[3].item(),
                    "learning_rate": optimizer.param_groups[0]["lr"],
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train: %.6f, Val: %.6f - "
                    "MAE: hue=%.6f, value=%.6f, chroma=%.6f, code=%.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                    mae[0],
                    mae[1],
                    mae[2],
                    mae[3],
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning.
            # NOTE(review): when early stopping breaks above, the final
            # epoch's value is never reported to Optuna — presumably
            # acceptable since the trial ends with best_val_loss anyway.
            trial.report(val_loss, epoch)

            # Handle pruning
            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results (mae/epoch refer to the last executed epoch)
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_mae_hue": mae[0].item(),
                "final_mae_value": mae[1].item(),
                "final_mae_chroma": mae[2].item(),
                "final_mae_code": mae[3].item(),
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
423
+
424
+
425
def main() -> None:
    """
    Run hyperparameter search for Multi-Head model (xyY to Munsell).

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 20 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    Results are written to ``results/from_xyY`` as a timestamped text
    report plus three PNG plots.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head (from_xyY) Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study; MedianPruner needs 3 complete trials and 10 warm-up
    # epochs before it starts pruning.
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_head_from_xyY_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    # Run optimization
    n_trials = 20  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save results
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_multi_head_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Head (from_xyY) Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        # Pruned trials have value=None and are reported as "Pruned".
        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            if t.value is not None:
                f.write(f" Value: {t.value:.6f}\n")
            else:
                f.write(" Value: Pruned\n")
            f.write(" Params:\n")
            for key, value in t.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib.
    # Imported here (not at module top) so the search itself does not
    # depend on optuna's optional visualization extras.
    from optuna.visualization.matplotlib import (
        plot_optimization_history,
        plot_param_importances,
        plot_parallel_coordinate,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(
        results_dir / f"optimization_history_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parameter importances
    # NOTE(review): presumably requires at least two completed trials to
    # compute importances — verify against the Optuna version in use.
    ax = plot_param_importances(study)
    ax.figure.savefig(
        results_dir / f"param_importances_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(
        results_dir / f"parallel_coordinate_multi_head_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
536
+
537
+
538
if __name__ == "__main__":
    # Message-only format keeps console output aligned with the banner-style
    # LOGGER calls; force=True replaces any handlers already installed by
    # imported libraries.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_head_error_predictor.py ADDED
@@ -0,0 +1,552 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-Head Error Predictor using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Width multipliers for each component branch (hue, value, chroma, code)
8
+ - Loss function component weights
9
+
10
+ Objective: Minimize validation loss (combined base + error predictor)
11
+ """
12
+
13
+ from __future__ import annotations
14
+
15
+ import logging
16
+ from datetime import datetime
17
+ from pathlib import Path
18
+
19
+ import matplotlib.pyplot as plt
20
+ import mlflow
21
+ import numpy as np
22
+ import onnxruntime as ort
23
+ import optuna
24
+ import torch
25
+ from numpy.typing import NDArray
26
+ from optuna.trial import Trial
27
+ from torch import nn, optim
28
+ from torch.utils.data import DataLoader, TensorDataset
29
+
30
+ from learning_munsell import PROJECT_ROOT
31
+ from learning_munsell.models.networks import ComponentErrorPredictor
32
+ from learning_munsell.utilities.common import setup_mlflow_experiment
33
+ from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
34
+ from learning_munsell.utilities.training import train_epoch, validate
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+
39
class MultiHeadErrorPredictorParametric(nn.Module):
    """
    Parametric Multi-Head error predictor with 4 independent branches.

    Wraps four separate ComponentErrorPredictor networks — one per
    Munsell component (hue, value, chroma, code) — so that each branch's
    capacity can be tuned independently during hyperparameter search.

    Parameters
    ----------
    hue_width : float, optional
        Width multiplier for the hue branch. Default is 1.0.
    value_width : float, optional
        Width multiplier for the value branch. Default is 1.0.
    chroma_width : float, optional
        Width multiplier for the chroma branch. Default is 1.5.
    code_width : float, optional
        Width multiplier for the code branch. Default is 1.0.
    """

    def __init__(
        self,
        hue_width: float = 1.0,
        value_width: float = 1.0,
        chroma_width: float = 1.5,
        code_width: float = 1.0,
    ) -> None:
        super().__init__()

        # One independent error-predictor network per Munsell component.
        self.hue_branch = ComponentErrorPredictor(width_multiplier=hue_width)
        self.value_branch = ComponentErrorPredictor(width_multiplier=value_width)
        self.chroma_branch = ComponentErrorPredictor(width_multiplier=chroma_width)
        self.code_branch = ComponentErrorPredictor(width_multiplier=code_width)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through all four error predictor branches.

        Parameters
        ----------
        x : torch.Tensor
            Input tensor of shape (batch_size, 7) containing normalized
            xyY values and base model predictions.

        Returns
        -------
        torch.Tensor
            Predicted errors for all components, shape (batch_size, 4).
            Output order: [hue_error, value_error, chroma_error, code_error].
        """
        # Every branch receives the same combined input and predicts its
        # own component's error independently of the others.
        branches = (
            self.hue_branch,
            self.value_branch,
            self.chroma_branch,
            self.code_branch,
        )
        per_component = [branch(x) for branch in branches]

        # Concatenate in fixed order: [hue, value, chroma, code] errors.
        return torch.cat(per_component, dim=1)
104
+
105
+
106
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base Multi-Head ONNX model and its normalization parameters.

    Parameters
    ----------
    model_path : Path
        Path to the base Multi-Head model ONNX file.
    params_path : Path
        Path to the normalization parameters NPZ file.

    Returns
    -------
    ort.InferenceSession
        ONNX Runtime inference session for the base model.
    dict
        Input normalization parameters (x_range, y_range, Y_range).
    dict
        Output normalization parameters (hue_range, value_range, chroma_range, code_range).
    """
    inference_session = ort.InferenceSession(str(model_path))

    # The NPZ stores the parameter dicts as 0-d object arrays, hence
    # allow_pickle + .item() to recover plain dicts.
    archive = np.load(params_path, allow_pickle=True)
    input_params = archive["input_params"].item()
    output_params = archive["output_params"].item()

    return inference_session, input_params, output_params
131
+
132
+
133
def create_weighted_loss(
    mse_weight: float,
    mae_weight: float,
    log_weight: float,
    huber_weight: float,
    huber_delta: float,
):
    """
    Create a weighted loss function combining multiple loss components.

    The returned closure blends four penalties on the prediction residual:
    squared error, absolute error, a logarithmic penalty on the scaled
    absolute error, and a Huber term.

    Parameters
    ----------
    mse_weight : float
        Weight for MSE component.
    mae_weight : float
        Weight for MAE component.
    log_weight : float
        Weight for logarithmic penalty component.
    huber_weight : float
        Weight for Huber loss component.
    huber_delta : float
        Delta parameter for Huber loss transition point.

    Returns
    -------
    callable
        Loss function that accepts (pred, target) and returns a scalar loss.
    """

    def weighted_loss(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """
        Compute weighted combination of loss components.

        Parameters
        ----------
        pred : torch.Tensor
            Predicted values, shape (batch_size, n_components).
        target : torch.Tensor
            Target values, shape (batch_size, n_components).

        Returns
        -------
        torch.Tensor
            Weighted combination of loss components, scalar tensor.
        """
        residual = pred - target
        abs_residual = residual.abs()

        # Quadratic penalty (standard MSE)
        mse_term = (residual**2).mean()

        # Absolute penalty (MAE)
        mae_term = abs_residual.mean()

        # Logarithmic penalty; the x1000 factor amplifies small
        # normalized residuals before log1p.
        log_term = torch.log1p(abs_residual * 1000.0).mean()

        # Huber: quadratic below huber_delta, linear above it.
        quadratic = 0.5 * abs_residual**2
        linear = huber_delta * (abs_residual - 0.5 * huber_delta)
        huber_term = torch.where(abs_residual <= huber_delta, quadratic, linear).mean()

        # Weighted blend of all four terms.
        return (
            mse_weight * mse_term
            + mae_weight * mae_term
            + log_weight * log_term
            + huber_weight * huber_term
        )

    return weighted_loss
+
206
+
207
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (512, 1024, or 2048)
    - Width multipliers for each component branch
    - Loss function weights (MSE, MAE, log penalty, Huber)
    - Huber delta parameter (0.005 to 0.02)

    The error predictor is trained on the residuals of a frozen base
    Multi-Head ONNX model: targets are (true - base prediction) in
    normalized space, and inputs are [normalized xyY, base prediction].

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
    hue_width = trial.suggest_float("hue_width", 0.75, 1.5, step=0.25)
    value_width = trial.suggest_float("value_width", 0.75, 1.5, step=0.25)
    chroma_width = trial.suggest_float("chroma_width", 1.0, 2.0, step=0.25)
    code_width = trial.suggest_float("code_width", 0.75, 1.5, step=0.25)

    # Loss function weights
    mse_weight = trial.suggest_float("mse_weight", 0.5, 2.0, step=0.5)
    mae_weight = trial.suggest_float("mae_weight", 0.0, 1.0, step=0.25)
    log_weight = trial.suggest_float("log_weight", 0.0, 0.5, step=0.1)
    huber_weight = trial.suggest_float("huber_weight", 0.0, 1.0, step=0.25)
    huber_delta = trial.suggest_float("huber_delta", 0.005, 0.02, step=0.005)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" hue_width: %.2f", hue_width)
    LOGGER.info(" value_width: %.2f", value_width)
    LOGGER.info(" chroma_width: %.2f", chroma_width)
    LOGGER.info(" code_width: %.2f", code_width)
    LOGGER.info(" mse_weight: %.2f", mse_weight)
    LOGGER.info(" mae_weight: %.2f", mae_weight)
    LOGGER.info(" log_weight: %.2f", log_weight)
    LOGGER.info(" huber_weight: %.2f", huber_weight)
    LOGGER.info(" huber_delta: %.3f", huber_delta)

    # Set device: MPS (Apple Silicon) when available, otherwise CPU.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info(" device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    base_model_path = model_directory / "multi_head.onnx"
    params_path = model_directory / "multi_head_normalization_params.npz"
    cache_file = data_dir / "training_data.npz"

    # Load base model (frozen; only used for inference below)
    base_session, input_params, output_params = load_base_model(
        base_model_path, params_path
    )

    # Load training data
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Normalize
    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)
    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Generate base model predictions.
    # NOTE(review): assumes the ONNX graph's input is named "xyY" and that
    # normalize_xyY returns float32 arrays matching the graph's dtype —
    # confirm against the export in the training script.
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors (regression targets for the error predictor)
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_loader = DataLoader(
        TensorDataset(X_train_t, error_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, error_val_t), batch_size=batch_size, shuffle=False
    )

    LOGGER.info(
        " Training samples: %d, Validation samples: %d", len(X_train_t), len(X_val_t)
    )

    # Initialize error predictor model
    model = MultiHeadErrorPredictorParametric(
        hue_width=hue_width,
        value_width=value_width,
        chroma_width=chroma_width,
        code_width=code_width,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info(" Total parameters: %s", f"{total_params:,}")

    # Training setup (fixed weight decay here; only lr is searched)
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )

    # Create loss function
    criterion = create_weighted_loss(
        mse_weight, mae_weight, log_weight, huber_weight, huber_delta
    )

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_head_error_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 50  # Reduced for hyperparameter search
    patience = 10
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "lr": lr,
                "batch_size": batch_size,
                "hue_width": hue_width,
                "value_width": value_width,
                "chroma_width": chroma_width,
                "code_width": code_width,
                "mse_weight": mse_weight,
                "mae_weight": mae_weight,
                "log_weight": log_weight,
                "huber_weight": huber_weight,
                "huber_delta": huber_delta,
                "total_params": total_params,
                "trial_number": trial.number,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)
            # ReduceLROnPlateau steps on the monitored metric, not per-epoch.
            scheduler.step(val_loss)

            # Log to MLflow
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": optimizer.param_groups[0]["lr"],
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train: %.6f, Val: %.6f, LR: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                    optimizer.param_groups[0]["lr"],
                )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning.
            # NOTE(review): when early stopping breaks above, the final
            # epoch's value is never reported to Optuna — presumably fine
            # since the trial returns best_val_loss regardless.
            trial.report(val_loss, epoch)

            # Handle pruning
            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
            }
        )

        LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
+ return best_val_loss
434
+
435
+
436
def main() -> None:
    """
    Run hyperparameter search for Multi-Head Error Predictor.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 30 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    The search aims to find optimal hyperparameters for predicting errors
    in a base Multi-Head model, allowing for error correction and improved
    Munsell predictions.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Error Predictor Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study.  MedianPruner kills trials whose intermediate val_loss is
    # worse than the running median, after 3 un-pruned startup trials and a
    # 5-epoch warm-up per trial.
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_head_error_predictor_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=5),
    )

    # Run optimization
    n_trials = 30  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    # `objective` is the module-level trial function; no wall-clock limit.
    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save a plain-text summary of the search alongside the plots.
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_multi_head_error_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-Head Error Predictor Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # Pruned/failed trials have `value is None`.
            if t.value is not None:
                f.write(f" Value: {t.value:.6f}\n")
            else:
                f.write(" Value: Pruned\n")
            f.write(" Params:\n")
            for key, value in t.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib.  Imported here rather than at
    # module top — presumably to keep the plotting extras out of the hot path;
    # matches the sibling hparam-search scripts.
    from optuna.visualization.matplotlib import (
        plot_optimization_history,
        plot_param_importances,
        plot_parallel_coordinate,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(
        results_dir / f"optimization_history_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parameter importances
    ax = plot_param_importances(study)
    ax.figure.savefig(
        results_dir / f"param_importances_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(
        results_dir / f"parallel_coordinate_multi_head_error_{timestamp}.png", dpi=150
    )
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
547
+
548
+
549
if __name__ == "__main__":
    # force=True replaces any handlers already installed by imported
    # libraries so the bare-message format takes effect.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/hyperparameter_search_multi_mlp.py ADDED
@@ -0,0 +1,471 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Hyperparameter search for Multi-MLP model using Optuna.
3
+
4
+ Optimizes:
5
+ - Learning rate
6
+ - Batch size
7
+ - Chroma width multiplier
8
+ - Chroma loss weight
9
+ - Code loss weight
10
+ - Dropout (optional)
11
+
12
+ Objective: Minimize validation loss
13
+ """
14
+
15
+ import logging
16
+ from datetime import datetime
17
+
18
+ import matplotlib.pyplot as plt
19
+ import mlflow
20
+ import numpy as np
21
+ import optuna
22
+ import torch
23
+ from numpy.typing import NDArray
24
+ from optuna.trial import Trial
25
+ from torch import nn, optim
26
+ from torch.utils.data import DataLoader, TensorDataset
27
+
28
+ from learning_munsell import PROJECT_ROOT
29
+ from learning_munsell.models.networks import MultiMLPToMunsell
30
+ from learning_munsell.utilities.common import setup_mlflow_experiment
31
+ from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
32
+
33
+ LOGGER = logging.getLogger(__name__)
34
+
35
+
36
def weighted_mse_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    hue_weight: float = 1.0,
    value_weight: float = 1.0,
    chroma_weight: float = 4.0,
    code_weight: float = 0.5,
) -> torch.Tensor:
    """
    Compute an MSE loss with per-component weights.

    Each of the four Munsell components (hue, value, chroma, code) receives
    its own weight so that harder or more important components contribute
    more to the gradient.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values, shape (batch_size, 4).
    target : torch.Tensor
        Target values, shape (batch_size, 4).
    hue_weight : float, optional
        Weight for hue component. Default is 1.0.
    value_weight : float, optional
        Weight for value component. Default is 1.0.
    chroma_weight : float, optional
        Weight for chroma component (typically higher). Default is 4.0.
    code_weight : float, optional
        Weight for code component (typically lower). Default is 0.5.

    Returns
    -------
    torch.Tensor
        Scalar weighted MSE loss (mean over all batch elements and
        components; note the mean is not re-normalised by the weight sum).
    """
    component_weights = torch.tensor(
        [hue_weight, value_weight, chroma_weight, code_weight],
        device=pred.device,
    )

    squared_error = torch.square(pred - target)
    return (squared_error * component_weights).mean()
77
+
78
+
79
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
    chroma_weight: float,
    code_weight: float,
) -> float:
    """
    Run one full optimization pass over the training set.

    Parameters
    ----------
    model : nn.Module
        Multi-MLP model to train.
    dataloader : DataLoader
        DataLoader providing training batches.
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on (CPU, CUDA, or MPS).
    chroma_weight : float
        Weight for chroma component in loss function.
    code_weight : float
        Weight for code component in loss function.

    Returns
    -------
    float
        Mean batch loss across the epoch.
    """
    model.train()
    running_loss = 0.0

    for inputs, targets in dataloader:
        inputs = inputs.to(device)
        targets = targets.to(device)

        # Forward pass through all component branches.
        predictions = model(inputs)
        batch_loss = weighted_mse_loss(
            predictions,
            targets,
            chroma_weight=chroma_weight,
            code_weight=code_weight,
        )

        # Standard zero-grad / backward / step update.
        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running_loss += batch_loss.item()

    return running_loss / len(dataloader)
130
+
131
+
132
def validate(
    model: nn.Module,
    dataloader: DataLoader,
    device: torch.device,
    chroma_weight: float,
    code_weight: float,
) -> float:
    """
    Evaluate the model on the validation set without gradient tracking.

    Parameters
    ----------
    model : nn.Module
        Multi-MLP model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches.
    device : torch.device
        Device to run validation on (CPU, CUDA, or MPS).
    chroma_weight : float
        Weight for chroma component in loss function.
    code_weight : float
        Weight for code component in loss function.

    Returns
    -------
    float
        Mean batch loss over the validation set.
    """
    model.eval()
    cumulative_loss = 0.0

    # Inference only: disable autograd for speed and memory.
    with torch.no_grad():
        for inputs, targets in dataloader:
            inputs = inputs.to(device)
            targets = targets.to(device)

            predictions = model(inputs)
            cumulative_loss += weighted_mse_loss(
                predictions,
                targets,
                chroma_weight=chroma_weight,
                code_weight=code_weight,
            ).item()

    return cumulative_loss / len(dataloader)
175
+
176
+
177
def objective(trial: Trial) -> float:
    """
    Optuna objective function to minimize validation loss.

    This function defines the hyperparameter search space and training
    procedure for each trial. It optimizes:
    - Learning rate (1e-4 to 1e-3, log scale)
    - Batch size (512, 1024, or 2048)
    - Chroma branch width multiplier (1.5 to 2.5)
    - Chroma loss weight (3.0 to 6.0)
    - Code loss weight (0.3 to 1.0)
    - Dropout rate (0.0 to 0.2)

    Parameters
    ----------
    trial : Trial
        Optuna trial object for suggesting hyperparameters.

    Returns
    -------
    float
        Best validation loss achieved during training.

    Raises
    ------
    FileNotFoundError
        If training data file is not found.
    optuna.TrialPruned
        If trial is pruned based on intermediate results.
    """

    # Suggest hyperparameters
    lr = trial.suggest_float("lr", 1e-4, 1e-3, log=True)
    batch_size = trial.suggest_categorical("batch_size", [512, 1024, 2048])
    chroma_width = trial.suggest_float("chroma_width", 1.5, 2.5, step=0.25)
    chroma_weight = trial.suggest_float("chroma_weight", 3.0, 6.0, step=0.5)
    code_weight = trial.suggest_float("code_weight", 0.3, 1.0, step=0.1)
    dropout = trial.suggest_float("dropout", 0.0, 0.2, step=0.05)

    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Trial %d", trial.number)
    LOGGER.info("=" * 80)
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" chroma_width: %.2f", chroma_width)
    LOGGER.info(" chroma_weight: %.1f", chroma_weight)
    LOGGER.info(" code_weight: %.1f", code_weight)
    LOGGER.info(" dropout: %.2f", dropout)

    # Set device.  NOTE(review): unlike other scripts in this package this
    # one does not check for MPS — confirm whether that is intentional.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")

    # Load training data.  The .npz is re-read from disk on every trial;
    # it is assumed to contain pre-split X_train/y_train/X_val/y_val arrays
    # produced by generate_training_data.py.
    data_file = PROJECT_ROOT / "data" / "training_data.npz"

    if not data_file.exists():
        LOGGER.error("Training data not found at %s", data_file)
        LOGGER.error("Run generate_training_data.py first")
        msg = f"Training data not found: {data_file}"
        raise FileNotFoundError(msg)

    data = np.load(data_file)

    # Use pre-split data
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info(
        "Loaded %d training samples, %d validation samples", len(X_train), len(X_val)
    )

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train = normalize_munsell(y_train, output_params)
    y_val = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors (float32)
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val)

    # Create data loaders; only the training set is shuffled.
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model with the trial's architecture hyperparameters.
    model = MultiMLPToMunsell(
        chroma_width_multiplier=chroma_width, dropout=dropout
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup (plain Adam, fixed lr — no scheduler in this script)
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # MLflow setup
    run_name = setup_mlflow_experiment(
        "from_xyY", f"hparam_multi_mlp_trial_{trial.number}"
    )

    # Training loop with early stopping
    num_epochs = 100  # Reduced for hyperparameter search
    patience = 15
    best_val_loss = float("inf")
    patience_counter = 0

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "trial": trial.number,
                "lr": lr,
                "batch_size": batch_size,
                "chroma_width": chroma_width,
                "chroma_weight": chroma_weight,
                "code_weight": code_weight,
                "dropout": dropout,
                "total_params": total_params,
            }
        )

        for epoch in range(num_epochs):
            train_loss = train_epoch(
                model, train_loader, optimizer, device, chroma_weight, code_weight
            )
            val_loss = validate(model, val_loader, device, chroma_weight, code_weight)

            # Log to MLflow ("learning_rate" is the constant trial lr since
            # there is no scheduler here).
            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "learning_rate": lr,
                },
                step=epoch,
            )

            if (epoch + 1) % 10 == 0:
                LOGGER.info(
                    " Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                    epoch + 1,
                    num_epochs,
                    train_loss,
                    val_loss,
                )

            # Early stopping.  Note the `break` fires before trial.report,
            # so the epoch that triggers early stopping is not reported to
            # the pruner.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info(" Early stopping at epoch %d", epoch + 1)
                    break

            # Report intermediate value for pruning
            trial.report(val_loss, epoch)

            # Handle pruning: mark the run as pruned in MLflow, then let
            # Optuna record the trial as PRUNED via the exception.
            if trial.should_prune():
                LOGGER.info(" Trial pruned at epoch %d", epoch + 1)
                mlflow.log_metrics({"pruned": 1, "pruned_epoch": epoch})
                raise optuna.TrialPruned

        # Log final results
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_train_loss": train_loss,
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info(" Final validation loss: %.6f", best_val_loss)

    return best_val_loss
362
+
363
+
364
def main() -> None:
    """
    Run hyperparameter search for Multi-MLP model.

    Performs systematic hyperparameter optimization using Optuna with:
    - MedianPruner for early stopping of unpromising trials
    - 15 total trials
    - MLflow logging for each trial
    - Result visualization using matplotlib (optimization history,
      parameter importances, parallel coordinate plot)

    The search aims to find optimal hyperparameters for converting xyY
    color coordinates to Munsell color specifications using a multi-MLP
    architecture with independent branches for each component.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-MLP Hyperparameter Search with Optuna")
    LOGGER.info("=" * 80)

    # Create study
    study = optuna.create_study(
        direction="minimize",
        study_name="multi_mlp_hparam_search",
        pruner=optuna.pruners.MedianPruner(n_startup_trials=3, n_warmup_steps=10),
    )

    # Run optimization
    n_trials = 15  # Number of trials to run

    LOGGER.info("")
    LOGGER.info("Starting hyperparameter search with %d trials...", n_trials)
    LOGGER.info("")

    study.optimize(objective, n_trials=n_trials, timeout=None)

    # Print results
    LOGGER.info("")
    LOGGER.info("=" * 80)
    LOGGER.info("Hyperparameter Search Results")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Best trial:")
    LOGGER.info(" Value (val_loss): %.6f", study.best_value)
    LOGGER.info("")
    LOGGER.info("Best hyperparameters:")
    for key, value in study.best_params.items():
        LOGGER.info(" %s: %s", key, value)

    # Save results.
    results_dir = PROJECT_ROOT / "results" / "from_xyY"
    # FIX: parents=True so a missing "results" directory is created instead
    # of raising FileNotFoundError (matches the sibling hparam scripts).
    results_dir.mkdir(exist_ok=True, parents=True)

    timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
    results_file = results_dir / f"hparam_search_{timestamp}.txt"

    with open(results_file, "w") as f:
        f.write("=" * 80 + "\n")
        f.write("Multi-MLP Hyperparameter Search Results\n")
        f.write("=" * 80 + "\n\n")
        f.write(f"Date: {datetime.now().strftime('%Y-%m-%d %H:%M:%S')}\n")
        f.write(f"Number of trials: {len(study.trials)}\n")
        f.write(f"Best validation loss: {study.best_value:.6f}\n\n")
        f.write("Best hyperparameters:\n")
        for key, value in study.best_params.items():
            f.write(f" {key}: {value}\n")
        f.write("\n\nAll trials:\n")
        f.write("-" * 80 + "\n")

        # Named `t` (not `trial`) to avoid shadowing the identifier used as
        # objective()'s parameter elsewhere in this module.
        for t in study.trials:
            f.write(f"\nTrial {t.number}:\n")
            # FIX: the previous version embedded a conditional expression in
            # the f-string format spec ({value:.6f if value else 'Pruned'});
            # everything after ':' is a literal format spec, so it raised
            # "ValueError: Invalid format specifier" at runtime.  Also test
            # for None explicitly — a val_loss of 0.0 is a valid value,
            # while pruned/failed trials have value None.
            if t.value is not None:
                f.write(f" Value: {t.value:.6f}\n")
            else:
                f.write(" Value: Pruned\n")
            f.write(" Params:\n")
            for key, value in t.params.items():
                f.write(f" {key}: {value}\n")

    LOGGER.info("")
    LOGGER.info("Results saved to: %s", results_file)

    # Generate visualizations using matplotlib (imported here, as in the
    # sibling scripts, so the plotting extras stay out of module import).
    from optuna.visualization.matplotlib import (
        plot_optimization_history,
        plot_param_importances,
        plot_parallel_coordinate,
    )

    # Optimization history
    ax = plot_optimization_history(study)
    ax.figure.savefig(results_dir / f"optimization_history_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    # Parameter importances
    ax = plot_param_importances(study)
    ax.figure.savefig(results_dir / f"param_importances_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    # Parallel coordinate plot
    ax = plot_parallel_coordinate(study)
    ax.figure.savefig(results_dir / f"parallel_coordinate_{timestamp}.png", dpi=150)
    plt.close(ax.figure)

    LOGGER.info("Visualizations saved to: %s", results_dir)
466
+
467
+
468
if __name__ == "__main__":
    # force=True replaces any handlers already installed by imported
    # libraries so the bare-message format takes effect.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/refine_multi_head_real.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Refine Multi-Head model on REAL Munsell colors only.
3
+
4
+ This script fine-tunes the best Multi-Head model using only the 2734 real
5
+ (measured) Munsell colors, which should improve accuracy on the evaluation set.
6
+ """
7
+
8
+ import logging
9
+ from typing import Any
10
+
11
+ import click
12
+ import mlflow
13
+ import mlflow.pytorch
14
+ import numpy as np
15
+ import torch
16
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
17
+ from colour.notation.munsell import (
18
+ munsell_colour_to_munsell_specification,
19
+ munsell_specification_to_xyY,
20
+ )
21
+ from numpy.typing import NDArray
22
+ from sklearn.model_selection import train_test_split
23
+ from torch import nn, optim
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
28
+ from learning_munsell.utilities.common import (
29
+ log_training_epoch,
30
+ setup_mlflow_experiment,
31
+ )
32
+ from learning_munsell.utilities.data import (
33
+ MUNSELL_NORMALIZATION_PARAMS,
34
+ XYY_NORMALIZATION_PARAMS,
35
+ normalize_munsell,
36
+ )
37
+ from learning_munsell.utilities.training import train_epoch, validate
38
+
39
+ LOGGER = logging.getLogger(__name__)
40
+
41
+
42
+ def generate_real_samples(
43
+ n_samples_per_color: int = 100,
44
+ perturbation_pct: float = 0.05,
45
+ ) -> tuple[NDArray, NDArray]:
46
+ """
47
+ Generate training samples from REAL (measured) Munsell colors only.
48
+
49
+ Creates augmented samples by applying small perturbations to the 2734 real
50
+ Munsell color specifications to increase training data while staying close
51
+ to measured values.
52
+
53
+ Parameters
54
+ ----------
55
+ n_samples_per_color : int, optional
56
+ Number of perturbed samples to generate per real color (default is 100).
57
+ perturbation_pct : float, optional
58
+ Percentage of range to use for perturbations (default is 0.05 = 5%).
59
+
60
+ Returns
61
+ -------
62
+ xyY_samples : NDArray
63
+ Array of shape (n_samples, 3) containing xyY coordinates.
64
+ munsell_samples : NDArray
65
+ Array of shape (n_samples, 4) containing Munsell specifications
66
+ [hue, value, chroma, code].
67
+
68
+ Notes
69
+ -----
70
+ Perturbations are applied uniformly within ±perturbation_pct of the
71
+ component ranges:
72
+ - Hue range: 9.5 (0.5 to 10.0)
73
+ - Value range: 9.0 (1.0 to 10.0)
74
+ - Chroma range: 50.0 (0.0 to 50.0)
75
+
76
+ Invalid samples (that cannot be converted to xyY) are skipped.
77
+ """
78
+ LOGGER.info(
79
+ "Generating samples from %d REAL Munsell colors...", len(MUNSELL_COLOURS_REAL)
80
+ )
81
+
82
+ np.random.seed(42)
83
+
84
+ hue_range = 9.5
85
+ value_range = 9.0
86
+ chroma_range = 50.0
87
+
88
+ xyY_samples = []
89
+ munsell_samples = []
90
+
91
+ for munsell_spec_tuple, _ in MUNSELL_COLOURS_REAL:
92
+ hue_code_str, value, chroma = munsell_spec_tuple
93
+ munsell_str = f"{hue_code_str} {value}/{chroma}"
94
+ base_spec = munsell_colour_to_munsell_specification(munsell_str)
95
+
96
+ for _ in range(n_samples_per_color):
97
+ hue_delta = np.random.uniform(
98
+ -perturbation_pct * hue_range, perturbation_pct * hue_range
99
+ )
100
+ value_delta = np.random.uniform(
101
+ -perturbation_pct * value_range, perturbation_pct * value_range
102
+ )
103
+ chroma_delta = np.random.uniform(
104
+ -perturbation_pct * chroma_range, perturbation_pct * chroma_range
105
+ )
106
+
107
+ perturbed_spec = base_spec.copy()
108
+ perturbed_spec[0] = np.clip(base_spec[0] + hue_delta, 0.5, 10.0)
109
+ perturbed_spec[1] = np.clip(base_spec[1] + value_delta, 1.0, 10.0)
110
+ perturbed_spec[2] = np.clip(base_spec[2] + chroma_delta, 0.0, 50.0)
111
+
112
+ try:
113
+ xyY = munsell_specification_to_xyY(perturbed_spec)
114
+ xyY_samples.append(xyY)
115
+ munsell_samples.append(perturbed_spec)
116
+ except Exception:
117
+ continue
118
+
119
+ LOGGER.info("Generated %d samples", len(xyY_samples))
120
+ return np.array(xyY_samples), np.array(munsell_samples)
121
+
122
+
123
+ @click.command()
124
+ @click.option("--epochs", default=200, help="Number of training epochs")
125
+ @click.option("--batch-size", default=512, help="Batch size for training")
126
+ @click.option("--lr", default=1e-5, help="Learning rate")
127
+ @click.option("--patience", default=30, help="Early stopping patience")
128
+ def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
129
+ """
130
+ Refine Multi-Head model on REAL Munsell colors only.
131
+
132
+ Fine-tunes a pretrained Multi-Head MLP model using only the 2734 real
133
+ (measured) Munsell colors with small perturbations. This refinement step
134
+ aims to improve accuracy on actual measured colors by focusing the model
135
+ on the real color gamut.
136
+
137
+ Notes
138
+ -----
139
+ Training configuration:
140
+ - Dataset: 2734 real Munsell colors with 200 samples per color
141
+ - Perturbation: 3% of component ranges (smaller than initial training)
142
+ - Learning rate: 1e-5 (lower for fine-tuning)
143
+ - Batch size: 512
144
+ - Early stopping: patience of 30 epochs
145
+ - Optimizer: AdamW with weight decay 0.01
146
+ - Scheduler: ReduceLROnPlateau with factor 0.5, patience 15
147
+
148
+ Workflow:
149
+ 1. Generate augmented samples from real Munsell colors
150
+ 2. Load pretrained model (multi_head_large_best.pth)
151
+ 3. Fine-tune with lower learning rate
152
+ 4. Save best model based on validation loss
153
+ 5. Export to ONNX format
154
+ 6. Log metrics to MLflow
155
+
156
+ Files generated:
157
+ - multi_head_refined_real_best.pth: Best checkpoint
158
+ - multi_head_refined_real.onnx: ONNX model
159
+ - multi_head_refined_real_normalization_params.npz: Normalization params
160
+ """
161
+ LOGGER.info("=" * 80)
162
+ LOGGER.info("Multi-Head Refinement on REAL Munsell Colors")
163
+ LOGGER.info("=" * 80)
164
+
165
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
166
+ if torch.backends.mps.is_available():
167
+ device = torch.device("mps")
168
+ LOGGER.info("Using device: %s", device)
169
+
170
+ # Generate REAL-only samples
171
+ LOGGER.info("")
172
+ xyY_all, munsell_all = generate_real_samples(
173
+ n_samples_per_color=200, # 200 samples per real color
174
+ perturbation_pct=0.03, # Smaller perturbations for refinement
175
+ )
176
+
177
+ # Split data
178
+ X_train, X_val, y_train, y_val = train_test_split(
179
+ xyY_all, munsell_all, test_size=0.15, random_state=42
180
+ )
181
+
182
+ LOGGER.info("Train samples: %d", len(X_train))
183
+ LOGGER.info("Validation samples: %d", len(X_val))
184
+
185
+ # Normalize outputs (xyY inputs are already in [0, 1] range)
186
+ # Use hardcoded ranges covering the full Munsell space for generalization
187
+ output_params = MUNSELL_NORMALIZATION_PARAMS
188
+ y_train_norm = normalize_munsell(y_train, output_params)
189
+ y_val_norm = normalize_munsell(y_val, output_params)
190
+
191
+ # Convert to tensors
192
+ X_train_t = torch.FloatTensor(X_train)
193
+ y_train_t = torch.FloatTensor(y_train_norm)
194
+ X_val_t = torch.FloatTensor(X_val)
195
+ y_val_t = torch.FloatTensor(y_val_norm)
196
+
197
+ # Data loaders
198
+ train_dataset = TensorDataset(X_train_t, y_train_t)
199
+ val_dataset = TensorDataset(X_val_t, y_val_t)
200
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
201
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
202
+
203
+ # Load pretrained model
204
+ model_directory = PROJECT_ROOT / "models" / "from_xyY"
205
+ pretrained_path = model_directory / "multi_head_large_best.pth"
206
+
207
+ model = MultiHeadMLPToMunsell().to(device)
208
+
209
+ if pretrained_path.exists():
210
+ LOGGER.info("")
211
+ LOGGER.info("Loading pretrained model from %s...", pretrained_path)
212
+ checkpoint = torch.load(
213
+ pretrained_path, weights_only=False, map_location=device
214
+ )
215
+ model.load_state_dict(checkpoint["model_state_dict"])
216
+ LOGGER.info("Pretrained model loaded successfully")
217
+ else:
218
+ LOGGER.info("")
219
+ LOGGER.info("No pretrained model found, training from scratch")
220
+
221
+ total_params = sum(p.numel() for p in model.parameters())
222
+ LOGGER.info("Total parameters: %s", f"{total_params:,}")
223
+
224
+ # Fine-tuning with lower learning rate
225
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=0.01)
226
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
227
+ optimizer, mode="min", factor=0.5, patience=15
228
+ )
229
+ criterion = nn.MSELoss()
230
+
231
+ # MLflow setup
232
+ run_name = setup_mlflow_experiment("from_xyY", "multi_head_refined_real")
233
+
234
+ LOGGER.info("")
235
+ LOGGER.info("MLflow run: %s", run_name)
236
+ LOGGER.info("Learning rate: %e (fine-tuning)", lr)
237
+
238
+ # Training loop
239
+ best_val_loss = float("inf")
240
+ patience_counter = 0
241
+
242
+ LOGGER.info("")
243
+ LOGGER.info("Starting refinement training...")
244
+
245
+ with mlflow.start_run(run_name=run_name):
246
+ mlflow.log_params(
247
+ {
248
+ "model": "multi_head_refined_real",
249
+ "learning_rate": lr,
250
+ "batch_size": batch_size,
251
+ "num_epochs": epochs,
252
+ "patience": patience,
253
+ "total_params": total_params,
254
+ "train_samples": len(X_train),
255
+ "val_samples": len(X_val),
256
+ "dataset": "REAL_only",
257
+ "perturbation_pct": 0.03,
258
+ "samples_per_color": 200,
259
+ }
260
+ )
261
+
262
+ for epoch in range(epochs):
263
+ train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
264
+ val_loss = validate(model, val_loader, criterion, device)
265
+
266
+ scheduler.step(val_loss)
267
+
268
+ log_training_epoch(
269
+ epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
270
+ )
271
+
272
+ LOGGER.info(
273
+ "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.2e",
274
+ epoch + 1,
275
+ epochs,
276
+ train_loss,
277
+ val_loss,
278
+ optimizer.param_groups[0]["lr"],
279
+ )
280
+
281
+ if val_loss < best_val_loss:
282
+ best_val_loss = val_loss
283
+ patience_counter = 0
284
+
285
+ checkpoint_file = model_directory / "multi_head_refined_real_best.pth"
286
+
287
+ torch.save(
288
+ {
289
+ "model_state_dict": model.state_dict(),
290
+ "output_params": output_params,
291
+ "epoch": epoch,
292
+ "val_loss": val_loss,
293
+ },
294
+ checkpoint_file,
295
+ )
296
+
297
+ LOGGER.info(" -> Saved best model (val_loss: %.6f)", val_loss)
298
+ else:
299
+ patience_counter += 1
300
+ if patience_counter >= patience:
301
+ LOGGER.info("")
302
+ LOGGER.info("Early stopping after %d epochs", epoch + 1)
303
+ break
304
+
305
+ mlflow.log_metrics(
306
+ {
307
+ "best_val_loss": best_val_loss,
308
+ "final_epoch": epoch + 1,
309
+ }
310
+ )
311
+
312
+ # Export to ONNX
313
+ LOGGER.info("")
314
+ LOGGER.info("Exporting refined model to ONNX...")
315
+ model.eval()
316
+
317
+ checkpoint = torch.load(checkpoint_file, weights_only=False)
318
+ model.load_state_dict(checkpoint["model_state_dict"])
319
+
320
+ model_cpu = model.cpu()
321
+ dummy_input = torch.randn(1, 3)
322
+
323
+ onnx_file = model_directory / "multi_head_refined_real.onnx"
324
+ torch.onnx.export(
325
+ model_cpu,
326
+ dummy_input,
327
+ onnx_file,
328
+ export_params=True,
329
+ opset_version=14,
330
+ input_names=["xyY"],
331
+ output_names=["munsell_spec"],
332
+ dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
333
+ )
334
+
335
+ params_file = (
336
+ model_directory / "multi_head_refined_real_normalization_params.npz"
337
+ )
338
+ input_params = XYY_NORMALIZATION_PARAMS
339
+ np.savez(
340
+ params_file,
341
+ input_params=input_params,
342
+ output_params=output_params,
343
+ )
344
+
345
+ mlflow.log_artifact(str(checkpoint_file))
346
+ mlflow.log_artifact(str(onnx_file))
347
+ mlflow.log_artifact(str(params_file))
348
+
349
+ LOGGER.info("ONNX model saved to: %s", onnx_file)
350
+ LOGGER.info("Normalization params saved to: %s", params_file)
351
+
352
+ LOGGER.info("=" * 80)
353
+
354
+
355
if __name__ == "__main__":
    # force=True replaces any handlers already installed by imported
    # libraries so the bare-message format takes effect.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_deep_wide.py ADDED
@@ -0,0 +1,371 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Deep + Wide model for xyY to Munsell conversion.
3
+
4
+ Option 5: Hybrid Deep + Wide architecture
5
+ - Input: 3 features (xyY)
6
+ - Deep path: 3 → 512 → 1024 (ResBlocks) → 512
7
+ - Wide path: 3 → 128 (direct linear)
8
+ - Combine: [512, 128] → 256 → 4
9
+ - Output: 4 features (hue, value, chroma, code)
10
+ """
11
+
12
+ import logging
13
+ from typing import Any
14
+
15
+ import click
16
+ import mlflow
17
+ import mlflow.pytorch
18
+ import numpy as np
19
+ import torch
20
+ from numpy.typing import NDArray
21
+ from torch import nn, optim
22
+ from torch.utils.data import DataLoader, TensorDataset
23
+
24
+ from learning_munsell import PROJECT_ROOT
25
+ from learning_munsell.models.networks import ResidualBlock
26
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
27
+ from learning_munsell.utilities.data import (
28
+ MUNSELL_NORMALIZATION_PARAMS,
29
+ XYY_NORMALIZATION_PARAMS,
30
+ normalize_munsell,
31
+ )
32
+ from learning_munsell.utilities.losses import precision_focused_loss
33
+ from learning_munsell.utilities.training import train_epoch, validate
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
class DeepWideNet(nn.Module):
    """
    Hybrid deep-and-wide network mapping xyY colours to Munsell specifications.

    Two parallel paths process the same 3-component input: a deep path that
    learns complex non-linear structure, and a wide path that keeps a direct,
    shallow view of the input. Their features are concatenated and projected
    to the 4 Munsell components.

    Parameters
    ----------
    num_residual_blocks : int, optional
        Number of residual blocks in the deep path. Default is 4.

    Attributes
    ----------
    deep_encoder : nn.Sequential
        Deep path encoder: 3 → 512 → 1024.
    deep_residual_blocks : nn.ModuleList
        Stack of residual blocks operating at width 1024.
    deep_decoder : nn.Sequential
        Deep path decoder: 1024 → 512.
    wide_path : nn.Sequential
        Wide path: 3 → 128.
    output_head : nn.Sequential
        Combined head: concat(512, 128) = 640 → 256 → 4.

    Notes
    -----
    Inspired by Google's Wide & Deep Learning: the deep path captures
    intricate feature interactions while the wide path preserves simple,
    near-linear relationships between input and output.
    """

    def __init__(self, num_residual_blocks: int = 4) -> None:
        """Build both paths and the shared output head."""
        super().__init__()

        # Deep path encoder: lift the 3 xyY features to a 1024-wide
        # representation (Linear → GELU → BatchNorm at each step).
        self.deep_encoder = nn.Sequential(
            nn.Linear(3, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.BatchNorm1d(1024),
        )

        # Residual refinement at constant width 1024.
        self.deep_residual_blocks = nn.ModuleList(
            ResidualBlock(1024) for _ in range(num_residual_blocks)
        )

        # Compress the deep representation back down to 512 features.
        self.deep_decoder = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
        )

        # Wide path: a single shallow projection of the raw input.
        self.wide_path = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.BatchNorm1d(128),
        )

        # Fused head over the concatenated deep (512) + wide (128) features.
        self.output_head = nn.Sequential(
            nn.Linear(640, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Run the input through both paths and the fused head.

        Parameters
        ----------
        x : Tensor
            Normalized xyY values, shape (batch_size, 3).

        Returns
        -------
        Tensor
            Normalized Munsell specification [hue, value, chroma, code],
            shape (batch_size, 4).
        """
        # Deep path: encode, refine through the residual stack, decode.
        h = self.deep_encoder(x)
        for residual in self.deep_residual_blocks:
            h = residual(h)
        deep_features = self.deep_decoder(h)

        # Wide path sees the raw input directly.
        wide_features = self.wide_path(x)

        # Fuse both views and project to the 4 Munsell components.
        fused = torch.cat((deep_features, wide_features), dim=1)
        return self.output_head(fused)
152
+
153
+
154
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the DeepWideNet model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Batch size for training and validation loaders.
    lr : float
        Initial learning rate for AdamW.
    patience : int
        Number of epochs without validation improvement before early stop.

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range (xyY inputs already are)
    3. Creates PyTorch DataLoaders
    4. Initializes DeepWideNet with deep and wide paths
    5. Trains with AdamW optimizer and precision-focused loss
    6. Uses learning rate scheduler (ReduceLROnPlateau)
    7. Implements early stopping based on validation loss
    8. Exports best model to ONNX format
    9. Logs all metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Deep + Wide Network: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths. Create the model directory (and any missing parents) up-front so
    # the checkpoint path is always defined and the mkdir is not repeated
    # inside the training loop.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "deep_wide_best.pth"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range).
    # Use hardcoded ranges covering the full Munsell space for generalization.
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = DeepWideNet(num_residual_blocks=4).to(device)
    LOGGER.info("")
    LOGGER.info("Deep + Wide architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "deep_wide")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log parameters
        mlflow.log_params(
            {
                "model": "deep_wide",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint only when validation improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX from the best checkpoint (not the last epoch's weights)
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        # weights_only=False: the checkpoint stores a plain dict written by
        # this script; being explicit keeps loading stable across PyTorch
        # versions (>= 2.6 defaults to weights_only=True) and matches the
        # sibling training scripts.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "deep_wide.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "deep_wide_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        # Log artifacts to MLflow
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
366
+
367
+
368
if __name__ == "__main__":
    # CLI entry point: configure root logging for plain-message console
    # output; force=True replaces any handlers installed by imported modules.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_ft_transformer.py ADDED
@@ -0,0 +1,356 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train FT-Transformer model for xyY to Munsell conversion.
3
+
4
+ Option 4: Feature Tokenizer + Transformer architecture
5
+ - Input: 3 features (xyY) → each becomes a 256-dim token
6
+ - Add [CLS] token for regression
7
+ - 4-6 transformer blocks with multi-head attention
8
+ - Output: Take [CLS] token → MLP → 4 features
9
+ """
10
+
11
+ import logging
12
+ import click
13
+ from typing import Any
14
+
15
+ import mlflow
16
+ import mlflow.pytorch
17
+ import numpy as np
18
+ import torch
19
+ from numpy.typing import NDArray
20
+ from torch import nn, optim
21
+ from torch.utils.data import DataLoader, TensorDataset
22
+
23
+ from learning_munsell import PROJECT_ROOT
24
+ from learning_munsell.models.networks import FeatureTokenizer, TransformerBlock
25
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
26
+ from learning_munsell.utilities.data import (
27
+ MUNSELL_NORMALIZATION_PARAMS,
28
+ XYY_NORMALIZATION_PARAMS,
29
+ normalize_munsell,
30
+ )
31
+ from learning_munsell.utilities.losses import precision_focused_loss
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
class FTTransformer(nn.Module):
    """
    Feature Tokenizer + Transformer for xyY to Munsell conversion.

    Adapts the transformer architecture to tabular input: every scalar
    feature is embedded as its own token, a learnable CLS token is prepended,
    and stacked self-attention blocks model feature interactions. The final
    CLS embedding is decoded into the 4 Munsell components.

    Parameters
    ----------
    num_features : int, optional
        Number of input features (xyY), default is 3.
    embedding_dim : int, optional
        Dimension of token embeddings, default is 256.
    num_blocks : int, optional
        Number of transformer blocks, default is 4.
    num_heads : int, optional
        Number of attention heads, default is 4.
    ff_dim : int, optional
        Feedforward network hidden dimension, default is 512.
    dropout : float, optional
        Dropout probability, default is 0.1.

    Attributes
    ----------
    tokenizer : FeatureTokenizer
        Embeds each input feature as a token and prepends the CLS token.
    transformer_blocks : nn.ModuleList
        Stack of self-attention transformer blocks.
    output_head : nn.Sequential
        MLP decoding the CLS embedding into the 4 outputs.
    """

    def __init__(
        self,
        num_features: int = 3,
        embedding_dim: int = 256,
        num_blocks: int = 4,
        num_heads: int = 4,
        ff_dim: int = 512,
        dropout: float = 0.1,
    ) -> None:
        """Assemble tokenizer, transformer stack, and regression head."""
        super().__init__()

        # Per-feature tokenization (adds the CLS token internally).
        self.tokenizer = FeatureTokenizer(num_features, embedding_dim)

        # Self-attention stack over the (1 + num_features) tokens.
        self.transformer_blocks = nn.ModuleList(
            TransformerBlock(embedding_dim, num_heads, ff_dim, dropout)
            for _ in range(num_blocks)
        )

        # Regression head applied to the CLS embedding only.
        self.output_head = nn.Sequential(
            nn.Linear(embedding_dim, 128),
            nn.GELU(),
            nn.Dropout(dropout),
            nn.Linear(128, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict Munsell components for a batch of xyY inputs.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Predicted Munsell specification [hue, value, chroma, code]
            of shape (batch_size, 4).
        """
        # (batch_size, 1 + num_features, embedding_dim) token sequence.
        tokens = self.tokenizer(x)

        # Run the full transformer stack.
        for transformer in self.transformer_blocks:
            tokens = transformer(tokens)

        # The CLS token (position 0) summarizes the sequence for regression.
        cls_embedding = tokens[:, 0]

        return self.output_head(cls_embedding)
135
+
136
+
137
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train FT-Transformer model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Batch size for training and validation loaders.
    lr : float
        Initial learning rate for AdamW.
    patience : int
        Number of epochs without validation improvement before early stop.

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range (xyY inputs already are)
    3. Creates PyTorch DataLoaders
    4. Initializes FT-Transformer with feature tokenization
    5. Trains with AdamW optimizer and precision-focused loss
    6. Uses learning rate scheduler (ReduceLROnPlateau)
    7. Implements early stopping based on validation loss
    8. Exports best model to ONNX format
    9. Logs all metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("FT-Transformer: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths. Create the model directory (and any missing parents) up-front so
    # the checkpoint path is always defined and the mkdir is not repeated
    # inside the training loop.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "ft_transformer_best.pth"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = FTTransformer(
        num_features=3,
        embedding_dim=256,
        num_blocks=4,
        num_heads=4,
        ff_dim=512,
        dropout=0.1,
    ).to(device)

    LOGGER.info("")
    LOGGER.info("FT-Transformer architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "ft_transformer")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "ft_transformer",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping: checkpoint only when validation improves.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX from the best checkpoint (not the last epoch's weights)
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        # weights_only=False: the checkpoint stores a plain dict written by
        # this script; being explicit keeps loading stable across PyTorch
        # versions (>= 2.6 defaults to weights_only=True) and matches the
        # sibling training scripts.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "ft_transformer.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "ft_transformer_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
351
+
352
+
353
if __name__ == "__main__":
    # CLI entry point: configure root logging for plain-message console
    # output; force=True replaces any handlers installed by imported modules.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mixture_of_experts.py ADDED
@@ -0,0 +1,620 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Mixture of Experts model for xyY to Munsell conversion.
3
+
4
+ Option 6: Mixture of Experts architecture
5
+ - Input: 3 features (xyY)
6
+ - Gating network: 3 → 128 → 64 → 4 (softmax weights)
7
+ - 4 Expert networks: Each 3 → 256 → 256 → 4 (MLP)
8
+ - Output: Weighted combination of expert outputs
9
+ """
10
+
11
+ import logging
12
+ import click
13
+
14
+ import mlflow
15
+ import mlflow.pytorch
16
+ import numpy as np
17
+ import torch
18
+ from numpy.typing import NDArray
19
+ from torch import nn, optim
20
+ from torch.utils.data import DataLoader, TensorDataset
21
+
22
+ from learning_munsell import PROJECT_ROOT
23
+ from learning_munsell.models.networks import ResidualBlock
24
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
25
+ from learning_munsell.utilities.data import (
26
+ MUNSELL_NORMALIZATION_PARAMS,
27
+ XYY_NORMALIZATION_PARAMS,
28
+ normalize_munsell,
29
+ )
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
class ExpertNetwork(nn.Module):
    """
    One expert of the mixture: a residual MLP over xyY inputs.

    Each expert learns to specialize on the region of colour space the
    gating network routes to it. Residual connections keep gradients
    flowing through the stack.

    Architecture
    ------------
    - Encoder: 3 → 256 (Linear → GELU → BatchNorm)
    - Residual stack: ``num_residual_blocks`` × ResidualBlock(256)
    - Decoder: 256 → 4

    Parameters
    ----------
    num_residual_blocks : int, optional
        Number of residual blocks, default is 2.

    Attributes
    ----------
    encoder : nn.Sequential
        Input projection to the 256-wide hidden space.
    residual_blocks : nn.ModuleList
        Residual refinement stack.
    decoder : nn.Sequential
        Final projection to the 4 Munsell components.
    """

    def __init__(self, num_residual_blocks: int = 2) -> None:
        """Build encoder, residual stack, and decoder."""
        super().__init__()

        self.encoder = nn.Sequential(
            nn.Linear(3, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
        )

        self.residual_blocks = nn.ModuleList(
            ResidualBlock(256) for _ in range(num_residual_blocks)
        )

        self.decoder = nn.Sequential(
            nn.Linear(256, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Produce this expert's prediction for a batch of inputs.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Expert's prediction of shape (batch_size, 4).
        """
        h = self.encoder(x)
        for residual in self.residual_blocks:
            h = residual(h)
        return self.decoder(h)
+ return self.decoder(x)
99
+
100
+
101
class GatingNetwork(nn.Module):
    """
    Routing network that assigns a mixing weight to each expert.

    A small MLP maps each xyY sample to one logit per expert; a softmax
    turns the logits into a probability distribution, so different inputs
    can activate different experts based on learned characteristics.

    Architecture
    ------------
    3 → 128 → 64 → num_experts → softmax

    Parameters
    ----------
    num_experts : int
        Number of expert networks to weight.

    Attributes
    ----------
    gate : nn.Sequential
        MLP producing unnormalized expert logits.
    """

    def __init__(self, num_experts: int) -> None:
        """Build the gating MLP with one output logit per expert."""
        super().__init__()

        self.gate = nn.Sequential(
            nn.Linear(3, 128),
            nn.GELU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 64),
            nn.GELU(),
            nn.BatchNorm1d(64),
            nn.Linear(64, num_experts),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Compute normalized expert weights for a batch of inputs.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        Tensor
            Softmax weights over experts of shape (batch_size, num_experts);
            non-negative and summing to 1 along the expert dimension.
        """
        logits = self.gate(x)
        # Normalize over the expert dimension into a probability simplex.
        return logits.softmax(dim=-1)
+ return torch.softmax(self.gate(x), dim=-1)
155
+
156
+
157
class MixtureOfExperts(nn.Module):
    """
    Mixture of Experts for xyY to Munsell conversion.

    Several specialized residual MLPs (experts) are blended per-sample by a
    learned gating distribution, letting each expert focus on a different
    region of the input space (e.g. different colour ranges or hue families).
    The gate weights are also returned so the training loss can add a load
    balancing term that discourages expert collapse.

    Parameters
    ----------
    num_experts : int, optional
        Number of expert networks, default is 4.
    num_residual_blocks : int, optional
        Number of residual blocks per expert, default is 2.

    Attributes
    ----------
    num_experts : int
        Number of expert networks.
    gating_network : GatingNetwork
        Produces per-sample softmax weights over the experts.
    experts : nn.ModuleList
        The expert networks.
    load_balance_weight : float
        Coefficient for the load balancing auxiliary loss.
    """

    def __init__(self, num_experts: int = 4, num_residual_blocks: int = 2) -> None:
        """Build the gate and the pool of experts."""
        super().__init__()

        self.num_experts = num_experts

        # Gate: per-sample distribution over experts.
        self.gating_network = GatingNetwork(num_experts)

        # Pool of identically shaped experts.
        self.experts = nn.ModuleList(
            ExpertNetwork(num_residual_blocks) for _ in range(num_experts)
        )

        # Coefficient applied to the load balancing auxiliary loss.
        self.load_balance_weight = 0.01

    def forward(self, x: torch.Tensor) -> tuple[torch.Tensor, torch.Tensor]:
        """
        Blend expert predictions according to the gating weights.

        Parameters
        ----------
        x : Tensor
            Input xyY values of shape (batch_size, 3).

        Returns
        -------
        tuple
            (output, gate_weights) where:
            - output: Weighted expert predictions of shape (batch_size, 4)
            - gate_weights: Expert weights of shape (batch_size, num_experts)
        """
        # Per-sample mixing distribution: (batch_size, num_experts).
        weights = self.gating_network(x)

        # Every expert sees every sample: (batch_size, num_experts, 4).
        stacked = torch.stack([expert(x) for expert in self.experts], dim=1)

        # Convex combination of expert outputs per sample: (batch_size, 4).
        mixed = (stacked * weights[..., None]).sum(dim=1)

        return mixed, weights
242
+
243
+
244
def precision_focused_loss(
    pred: torch.Tensor,
    target: torch.Tensor,
    gate_weights: torch.Tensor,
    load_balance_weight: float = 0.01,
) -> torch.Tensor:
    """
    Precision-focused loss with load balancing for a mixture of experts.

    Combines several regression terms (MSE, MAE, a logarithmic penalty and a
    tight Huber term) with an auxiliary load-balancing loss that pushes the
    gate towards uniform expert usage, preventing expert collapse.

    Parameters
    ----------
    pred : torch.Tensor
        Predicted values.
    target : torch.Tensor
        Target ground truth values.
    gate_weights : torch.Tensor
        Expert gating weights of shape (batch_size, num_experts).
    load_balance_weight : float, optional
        Weight for the load balancing auxiliary loss, default is 0.01.

    Returns
    -------
    torch.Tensor
        Combined scalar loss including the load balancing term.

    Notes
    -----
    The load balancing term compares each expert's share of the total gate
    mass against the uniform target 1/num_experts, so no expert is left idle
    while others absorb all the data.
    """
    error = pred - target
    abs_error = torch.abs(error)

    # Classical regression terms.
    mse = torch.mean(error**2)
    mae = torch.mean(abs_error)

    # Log penalty magnifies very small residuals (scaled by 1000).
    log_penalty = torch.mean(torch.log1p(abs_error * 1000.0))

    # Huber term with a tight transition point at delta.
    delta = 0.01
    huber_loss = torch.mean(
        torch.where(
            abs_error <= delta,
            0.5 * abs_error**2,
            delta * (abs_error - 0.5 * delta),
        )
    )

    # Load balancing: each expert's share of the total gate mass ...
    usage = gate_weights.sum(dim=0)  # (num_experts,)
    usage = usage / usage.sum()
    # ... should match the uniform distribution 1/num_experts.
    uniform = torch.ones_like(usage) / gate_weights.size(1)
    load_balance_loss = torch.mean((usage - uniform) ** 2)

    return (
        1.0 * mse
        + 0.5 * mae
        + 0.3 * log_penalty
        + 0.5 * huber_loss
        + load_balance_weight * load_balance_loss
    )
308
+
309
+
310
def train_epoch(
    model: nn.Module,
    dataloader: DataLoader,
    optimizer: optim.Optimizer,
    device: torch.device,
) -> float:
    """
    Run one optimisation pass of the mixture of experts over the loader.

    Parameters
    ----------
    model : nn.Module
        The neural network model to train.
    dataloader : DataLoader
        DataLoader providing training batches (X, y).
    optimizer : optim.Optimizer
        Optimizer for updating model parameters.
    device : torch.device
        Device to run training on.

    Returns
    -------
    float
        Average loss for the epoch.

    Notes
    -----
    The loss combines the prediction error with a load balancing term:
    the gate weights returned by the model are forwarded to
    precision_focused_loss for that purpose.
    """
    model.train()
    running = 0.0

    for features, targets in dataloader:
        features = features.to(device)
        targets = targets.to(device)

        predictions, gates = model(features)
        batch_loss = precision_focused_loss(
            predictions, targets, gates, model.load_balance_weight
        )

        optimizer.zero_grad()
        batch_loss.backward()
        optimizer.step()

        running += batch_loss.item()

    return running / len(dataloader)
359
+
360
+
361
def validate(model: nn.Module, dataloader: DataLoader, device: torch.device) -> float:
    """
    Evaluate the mixture of experts model on a validation set.

    Parameters
    ----------
    model : nn.Module
        The neural network model to validate.
    dataloader : DataLoader
        DataLoader providing validation batches (X, y).
    device : torch.device
        Device to run validation on.

    Returns
    -------
    float
        Average loss over the validation set.
    """
    model.eval()
    running = 0.0

    # Gradients are unnecessary during evaluation.
    with torch.no_grad():
        for features, targets in dataloader:
            features = features.to(device)
            targets = targets.to(device)

            predictions, gates = model(features)
            running += precision_focused_loss(
                predictions, targets, gates, model.load_balance_weight
            ).item()

    return running / len(dataloader)
394
+
395
+
396
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=3e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train mixture of experts model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Training batch size.
    lr : float
        Learning rate for the AdamW optimizer.
    patience : int
        Early stopping patience (epochs without improvement).

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] using the hardcoded Munsell ranges
    3. Creates PyTorch DataLoaders
    4. Initializes MixtureOfExperts with 4 expert networks
    5. Trains with AdamW optimizer and precision-focused loss
    6. Uses learning rate scheduler (ReduceLROnPlateau)
    7. Implements early stopping based on validation loss
    8. Exports best model to ONNX format
    9. Logs all metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Mixture of Experts: xyY → Munsell")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Load training data
    LOGGER.info("")
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MixtureOfExperts(num_experts=4, num_residual_blocks=2).to(device)
    LOGGER.info("")
    LOGGER.info("Mixture of Experts architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "mixture_of_experts")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "mixture_of_experts",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, device)
            val_loss = validate(model, val_loader, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # parents=True so a missing "models" parent directory does not
                # abort the very first checkpoint save.
                model_directory.mkdir(parents=True, exist_ok=True)
                checkpoint_file = model_directory / "mixture_of_experts_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX (simplified - outputs only prediction, not gate weights)
        LOGGER.info("")
        LOGGER.info("Exporting to ONNX...")
        model.eval()

        # Reload the best checkpoint before export; map_location keeps the
        # load working even if the checkpoint device differs.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create wrapper for ONNX export (only return prediction)
        class MoEWrapper(nn.Module):
            """Expose only the blended prediction, dropping gate weights."""

            def __init__(self, moe_model: nn.Module) -> None:
                super().__init__()
                self.moe_model = moe_model

            def forward(self, x: torch.Tensor) -> torch.Tensor:
                output, _ = self.moe_model(x)
                return output

        wrapped_model = MoEWrapper(model).to(device)
        wrapped_model.eval()

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "mixture_of_experts.onnx"
        torch.onnx.export(
            wrapped_model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={
                "xyY": {0: "batch_size"},
                "munsell_spec": {0: "batch_size"},
            },
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "mixture_of_experts_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
615
+
616
+
617
+ if __name__ == "__main__":
618
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
619
+
620
+ main()
learning_munsell/training/from_xyY/train_mlp.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train ML model for xyY to Munsell conversion.
3
+
4
+ This script trains a compact MLP/DNN model with architecture:
5
+ 3 inputs → [64, 128, 128, 64] hidden layers → 4 outputs
6
+
7
+ Target: < 1e-7 accuracy compared to iterative algorithm
8
+ """
9
+
10
+ import logging
11
+
12
+ import click
13
+ import mlflow
14
+ import mlflow.pytorch
15
+ import numpy as np
16
+ import torch
17
+ from torch import optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import MLPToMunsell
22
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
23
+ from learning_munsell.utilities.data import (
24
+ MUNSELL_NORMALIZATION_PARAMS,
25
+ XYY_NORMALIZATION_PARAMS,
26
+ normalize_munsell,
27
+ )
28
+ from learning_munsell.utilities.losses import weighted_mse_loss
29
+ from learning_munsell.utilities.training import train_epoch, validate
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
@click.command()
@click.option("--epochs", default=200, help="Maximum training epochs.")
@click.option("--batch-size", default=1024, help="Training batch size.")
@click.option("--lr", default=5e-4, help="Learning rate.")
@click.option("--patience", default=20, help="Early stopping patience.")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the MLPToMunsell model for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Training batch size.
    lr : float
        Learning rate for the Adam optimizer.
    patience : int
        Early stopping patience (epochs without improvement).

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes Munsell outputs to [0, 1] range
    3. Trains compact MLP model (3 → [64, 128, 128, 64] → 4)
    4. Uses weighted MSE loss function
    5. Learning rate scheduling with ReduceLROnPlateau
    6. Early stopping based on validation loss
    7. Exports model to ONNX format
    8. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Model Training")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    # Note: Invalid samples (outside Munsell gamut) are also stored in the cache
    # Available as: data['xyY_all'], data['munsell_all'], data['valid_mask']
    # These can be used for future enhancements like:
    # - Adversarial training to avoid extrapolation
    # - Gamut-aware loss functions
    # - Uncertainty estimation at boundaries

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use hardcoded ranges covering the full Munsell space for generalization
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    # Larger batch size for larger dataset (500K samples)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup - lower learning rate for larger model
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Use weighted MSE with default weights. The tensor is created on the
    # training device so the loss also works when the model runs on CUDA
    # (presumably weighted_mse_loss multiplies weights with the residuals —
    # a CPU tensor there would raise a device mismatch; verify against the
    # helper if behavior differs).
    weights = torch.tensor([1.0, 1.0, 2.0, 0.5], device=device)

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Weighted MSE closure binding the per-channel weights."""
        return weighted_mse_loss(pred, target, weights)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "mlp")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log hyperparameters
        mlflow.log_params(
            {
                "epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model; parents=True so a missing "models" parent
                # directory does not abort the very first checkpoint save.
                model_directory = PROJECT_ROOT / "models" / "from_xyY"
                model_directory.mkdir(parents=True, exist_ok=True)
                checkpoint_file = model_directory / "mlp_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Load best model; map_location keeps the load working even if the
        # checkpoint device differs.
        checkpoint = torch.load(checkpoint_file, map_location=device)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input
        dummy_input = torch.randn(1, 3).to(device)

        # Export
        onnx_file = model_directory / "mlp.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "mlp_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)

        # Log artifacts
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))

        # Log model
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
264
+
265
+
266
+ if __name__ == "__main__":
267
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
268
+
269
+ main()
learning_munsell/training/from_xyY/train_mlp_attention.py ADDED
@@ -0,0 +1,460 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train MLP + Self-Attention model for xyY to Munsell conversion.
3
+
4
+ Option 1: MLP backbone with multi-head self-attention layers
5
+ - Input: 3 features (xyY)
6
+ - Architecture: 3 -> 512 -> 1024 + [Attention + ResBlock] x 4 -> 512 -> 4
7
+ - Output: 4 features (hue, value, chroma, code)
8
+ """
9
+
10
+ import logging
11
+ import click
12
+ import mlflow
13
+ import mlflow.pytorch
14
+ import numpy as np
15
+ import torch
16
+ from numpy.typing import NDArray
17
+ from torch import nn, optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import ResidualBlock
22
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
23
+ from learning_munsell.utilities.data import (
24
+ MUNSELL_NORMALIZATION_PARAMS,
25
+ XYY_NORMALIZATION_PARAMS,
26
+ normalize_munsell,
27
+ )
28
+ from learning_munsell.utilities.losses import precision_focused_loss
29
+ from learning_munsell.utilities.training import train_epoch, validate
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+
34
class MultiHeadSelfAttention(nn.Module):
    """
    Multi-head self-attention over a flat feature vector.

    Splits the incoming ``dim``-sized features into ``num_heads`` groups and
    lets the groups attend to one another via scaled dot-product attention,
    then merges them back through a final output projection.

    Parameters
    ----------
    dim
        Input and output feature dimension.
    num_heads
        Number of attention heads. Must divide ``dim`` evenly.

    Attributes
    ----------
    query
        Linear projection producing query vectors.
    key
        Linear projection producing key vectors.
    value
        Linear projection producing value vectors.
    out
        Output projection applied after attention.
    scale
        Scaling factor (1/sqrt(head_dim)) for the dot products.
    """

    def __init__(self, dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.num_heads = num_heads
        self.dim = dim
        self.head_dim = dim // num_heads

        assert dim % num_heads == 0, "dim must be divisible by num_heads"  # noqa: S101

        self.query = nn.Linear(dim, dim)
        self.key = nn.Linear(dim, dim)
        self.value = nn.Linear(dim, dim)
        self.out = nn.Linear(dim, dim)

        # 1/sqrt(head_dim) keeps the dot products in a numerically sane range.
        self.scale = self.head_dim**-0.5

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply multi-head self-attention.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, dim)``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, dim)`` with attention applied.
        """
        batch = x.size(0)
        split = (batch, self.num_heads, self.head_dim)

        # Project and split the feature vector into per-head chunks.
        q = self.query(x).view(split)
        k = self.key(x).view(split)
        v = self.value(x).view(split)

        # Scaled dot-product attention across the head axis.
        scores = torch.matmul(q, k.transpose(-2, -1)) * self.scale
        attn = scores.softmax(dim=-1)

        # Mix the value vectors, merge heads and project back to dim.
        mixed = torch.matmul(attn, v).reshape(batch, self.dim)
        return self.out(mixed)
109
+
110
+
111
class AttentionResBlock(nn.Module):
    """
    Self-attention stage followed by a residual MLP stage.

    The attention output is added back to its input (skip connection) and
    batch-normalised; the result then passes through a residual block and a
    second batch normalisation.

    Parameters
    ----------
    dim
        Input and output feature dimension.
    num_heads
        Number of attention heads forwarded to the self-attention layer.

    Attributes
    ----------
    attention
        Multi-head self-attention layer.
    norm1
        Batch normalization after attention.
    residual
        Residual MLP block.
    norm2
        Batch normalization after the residual block.
    """

    def __init__(self, dim: int, num_heads: int = 4) -> None:
        super().__init__()
        self.attention = MultiHeadSelfAttention(dim, num_heads)
        self.norm1 = nn.BatchNorm1d(dim)
        self.residual = ResidualBlock(dim)
        self.norm2 = nn.BatchNorm1d(dim)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply the attention and residual transformations.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, dim)``.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, dim)``.
        """
        # Attention sub-layer with skip connection, then normalisation.
        attended = self.norm1(self.attention(x) + x)
        # Residual MLP sub-layer, then normalisation.
        return self.norm2(self.residual(attended))
162
+
163
+
164
class MLPAttention(nn.Module):
    """
    MLP with self-attention for xyY to Munsell conversion.

    Architecture:
    - Input: 3 features (xyY normalized to [0, 1])
    - Encoder: 3 -> 512 -> 1024
    - Attention-ResBlocks at 1024-dim (configurable count)
    - Decoder: 1024 -> 512 -> 4
    - Output: 4 features (hue, value, chroma, code normalized)

    Parameters
    ----------
    num_blocks
        Number of attention-residual blocks in the middle.
    num_heads
        Number of attention heads in each attention layer.

    Attributes
    ----------
    encoder
        MLP that lifts the 3D xyY input into a 1024D feature space.
    blocks
        List of AttentionResBlock modules applied at constant width.
    decoder
        MLP that projects the 1024D features down to the 4D Munsell output.
    """

    def __init__(self, num_blocks: int = 4, num_heads: int = 4) -> None:
        super().__init__()

        # 3D xyY -> 1024D feature space.
        self.encoder = nn.Sequential(
            nn.Linear(3, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 1024),
            nn.GELU(),
            nn.BatchNorm1d(1024),
        )

        # Stack of attention + residual stages at constant width.
        self.blocks = nn.ModuleList(
            AttentionResBlock(1024, num_heads) for _ in range(num_blocks)
        )

        # 1024D features -> 4D Munsell specification.
        self.decoder = nn.Sequential(
            nn.Linear(1024, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
            nn.Linear(512, 4),
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict a Munsell specification from an xyY input.

        Parameters
        ----------
        x
            Input tensor of shape ``(batch_size, 3)`` containing normalized
            xyY values.

        Returns
        -------
        torch.Tensor
            Output tensor of shape ``(batch_size, 4)`` containing the
            normalized Munsell specification [hue, value, chroma, code].
        """
        hidden = self.encoder(x)

        for stage in self.blocks:
            hidden = stage(hidden)

        return self.decoder(hidden)
243
+
244
+
245
+ @click.command()
246
+ @click.option("--epochs", default=200, help="Number of training epochs")
247
+ @click.option("--batch-size", default=1024, help="Batch size for training")
248
+ @click.option("--lr", default=3e-4, help="Learning rate")
249
+ @click.option("--patience", default=20, help="Early stopping patience")
250
+ def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
251
+ """
252
+ Train MLP + Self-Attention model for xyY to Munsell conversion.
253
+
254
+ Notes
255
+ -----
256
+ The training pipeline:
257
+ 1. Loads normalization parameters and training data from disk
258
+ 2. Normalizes inputs (xyY) and outputs (Munsell specification) to [0, 1]
259
+ 3. Creates MLPAttention model (4 blocks, 4 attention heads)
260
+ 4. Trains with precision-focused loss (MSE + MAE + log + Huber)
261
+ 5. Uses AdamW optimizer with ReduceLROnPlateau scheduler
262
+ 6. Applies early stopping based on validation loss (patience=20)
263
+ 7. Exports best model to ONNX format
264
+ 8. Logs metrics and artifacts to MLflow
265
+ """
266
+ LOGGER.info("=" * 80)
267
+ LOGGER.info("MLP + Self-Attention: xyY → Munsell")
268
+ LOGGER.info("=" * 80)
269
+
270
+ # Set device
271
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
272
+ LOGGER.info("Using device: %s", device)
273
+
274
+ # Paths
275
+ model_directory = PROJECT_ROOT / "models" / "from_xyY"
276
+ data_dir = PROJECT_ROOT / "data"
277
+ cache_file = data_dir / "training_data.npz"
278
+
279
+ # Load training data
280
+ LOGGER.info("")
281
+ LOGGER.info("Loading training data from %s...", cache_file)
282
+ data = np.load(cache_file)
283
+ X_train = data["X_train"]
284
+ y_train = data["y_train"]
285
+ X_val = data["X_val"]
286
+ y_val = data["y_val"]
287
+
288
+ LOGGER.info("Train samples: %d", len(X_train))
289
+ LOGGER.info("Validation samples: %d", len(X_val))
290
+
291
+ output_params = MUNSELL_NORMALIZATION_PARAMS
292
+ y_train_norm = normalize_munsell(y_train, output_params)
293
+ y_val_norm = normalize_munsell(y_val, output_params)
294
+
295
+ # Convert to PyTorch tensors
296
+ X_train_t = torch.FloatTensor(X_train)
297
+ y_train_t = torch.FloatTensor(y_train_norm)
298
+ X_val_t = torch.FloatTensor(X_val)
299
+ y_val_t = torch.FloatTensor(y_val_norm)
300
+
301
+ # Create data loaders
302
+ train_dataset = TensorDataset(X_train_t, y_train_t)
303
+ val_dataset = TensorDataset(X_val_t, y_val_t)
304
+
305
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
306
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
307
+
308
+ # Initialize model
309
+ model = MLPAttention(num_blocks=4, num_heads=4).to(device)
310
+ LOGGER.info("")
311
+ LOGGER.info("MLP + Attention architecture:")
312
+ LOGGER.info("%s", model)
313
+
314
+ # Count parameters
315
+ total_params = sum(p.numel() for p in model.parameters())
316
+ LOGGER.info("Total parameters: %s", f"{total_params:,}")
317
+
318
+ # Training setup
319
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
320
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
321
+ optimizer, mode="min", factor=0.5, patience=5
322
+ )
323
+ criterion = precision_focused_loss
324
+
325
+ # MLflow setup
326
+ run_name = setup_mlflow_experiment("from_xyY", "mlp_attention")
327
+
328
+ LOGGER.info("")
329
+ LOGGER.info("MLflow run: %s", run_name)
330
+
331
+ # Training loop
332
+ best_val_loss = float("inf")
333
+ patience_counter = 0
334
+
335
+ LOGGER.info("")
336
+ LOGGER.info("Starting training...")
337
+
338
+ with mlflow.start_run(run_name=run_name):
339
+ # Log hyperparameters
340
+ mlflow.log_params(
341
+ {
342
+ "num_epochs": epochs,
343
+ "batch_size": batch_size,
344
+ "learning_rate": lr,
345
+ "weight_decay": 1e-5,
346
+ "optimizer": "AdamW",
347
+ "scheduler": "ReduceLROnPlateau",
348
+ "criterion": "precision_focused_loss",
349
+ "patience": patience,
350
+ "total_params": total_params,
351
+ "num_blocks": 4,
352
+ "num_heads": 4,
353
+ }
354
+ )
355
+
356
+ for epoch in range(epochs):
357
+ train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
358
+ val_loss = validate(model, val_loader, criterion, device)
359
+
360
+ scheduler.step(val_loss)
361
+
362
+ log_training_epoch(
363
+ epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
364
+ )
365
+
366
+ LOGGER.info(
367
+ "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
368
+ epoch + 1,
369
+ epochs,
370
+ train_loss,
371
+ val_loss,
372
+ optimizer.param_groups[0]["lr"],
373
+ )
374
+
375
+ # Early stopping
376
+ if val_loss < best_val_loss:
377
+ best_val_loss = val_loss
378
+ patience_counter = 0
379
+
380
+ model_directory.mkdir(exist_ok=True)
381
+ checkpoint_file = model_directory / "mlp_attention_best.pth"
382
+
383
+ torch.save(
384
+ {
385
+ "model_state_dict": model.state_dict(),
386
+ "epoch": epoch,
387
+ "val_loss": val_loss,
388
+ },
389
+ checkpoint_file,
390
+ )
391
+
392
+ LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
393
+ else:
394
+ patience_counter += 1
395
+ if patience_counter >= patience:
396
+ LOGGER.info("")
397
+ LOGGER.info("Early stopping after %d epochs", epoch + 1)
398
+ break
399
+
400
+ # Log final metrics
401
+ mlflow.log_metrics(
402
+ {
403
+ "best_val_loss": best_val_loss,
404
+ "final_epoch": epoch + 1,
405
+ }
406
+ )
407
+
408
+ # Export to ONNX
409
+ LOGGER.info("")
410
+ LOGGER.info("Exporting to ONNX...")
411
+ model.eval()
412
+
413
+ checkpoint = torch.load(checkpoint_file)
414
+ model.load_state_dict(checkpoint["model_state_dict"])
415
+
416
+ dummy_input = torch.randn(1, 3).to(device)
417
+
418
+ onnx_file = model_directory / "mlp_attention.onnx"
419
+ torch.onnx.export(
420
+ model,
421
+ dummy_input,
422
+ onnx_file,
423
+ export_params=True,
424
+ opset_version=15,
425
+ input_names=["xyY"],
426
+ output_names=["munsell_spec"],
427
+ dynamic_axes={
428
+ "xyY": {0: "batch_size"},
429
+ "munsell_spec": {0: "batch_size"},
430
+ },
431
+ )
432
+
433
+ # Save normalization parameters alongside model
434
+ params_file = model_directory / "mlp_attention_normalization_params.npz"
435
+ input_params = XYY_NORMALIZATION_PARAMS
436
+ np.savez(
437
+ params_file,
438
+ input_params=input_params,
439
+ output_params=output_params,
440
+ )
441
+
442
+ LOGGER.info("ONNX model saved to: %s", onnx_file)
443
+ LOGGER.info("Normalization parameters saved to: %s", params_file)
444
+
445
+ # Log artifacts
446
+ mlflow.log_artifact(str(checkpoint_file))
447
+ mlflow.log_artifact(str(onnx_file))
448
+ mlflow.log_artifact(str(params_file))
449
+
450
+ # Log model
451
+ mlflow.pytorch.log_model(model, "model")
452
+
453
+
454
+ LOGGER.info("=" * 80)
455
+
456
+
457
if __name__ == "__main__":
    # Configure root logging for CLI usage; ``force=True`` replaces any
    # handlers previously installed by imported libraries (e.g. mlflow).
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp_error_predictor.py ADDED
@@ -0,0 +1,457 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train error predictor with advanced MLP architecture.
3
+
4
+ Architecture features:
5
+ - Larger capacity: 7 → 256 → 512 → 512 → 256 → 4
6
+ - Residual connections (MLP-style) for better gradient flow
7
+ - Modern activation functions (GELU instead of ReLU)
8
+ - Precision-focused loss function
9
+
10
+ Generic error predictor that can work with any base model.
11
+ """
12
+
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import onnxruntime as ort
22
+ import torch
23
+ from numpy.typing import NDArray
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import ResidualBlock
29
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
30
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
31
+ from learning_munsell.utilities.losses import precision_focused_loss
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+ # Note: This script has a custom ErrorPredictorMLP architecture
37
+ # so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules.
38
+
39
+
40
class ErrorPredictorMLP(nn.Module):
    """
    Advanced error predictor with residual connections.

    Second stage of a two-stage Munsell prediction pipeline: a base model
    first predicts a Munsell specification from xyY coordinates, and this
    network then learns the residual correction to add to that prediction.
    Its input is the concatenation of the normalized xyY coordinates and the
    base model's normalized prediction (7 features total); its output is the
    4-component correction.

    Architecture:
    - Encoder: 7 → 256 → 512 (Linear → GELU → BatchNorm per stage)
    - ``num_residual_blocks`` residual blocks at 512 dimensions
    - Decoder: 512 → 256 → 128 → 4

    Parameters
    ----------
    num_residual_blocks : int, optional
        Number of residual blocks applied between encoder and decoder.
        Default is 3.

    Attributes
    ----------
    encoder : nn.Sequential
        Maps the 7-D input to a 512-D representation.
    residual_blocks : nn.ModuleList
        Residual blocks for deep feature extraction.
    decoder : nn.Sequential
        Maps the 512-D representation to the 4-D error prediction.
    """

    def __init__(self, num_residual_blocks: int = 3) -> None:
        super().__init__()

        # Module names and layer ordering are kept identical to the original
        # layout so existing checkpoint state_dict keys remain compatible.
        encoder_layers = [
            nn.Linear(7, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 512),
            nn.GELU(),
            nn.BatchNorm1d(512),
        ]
        self.encoder = nn.Sequential(*encoder_layers)

        self.residual_blocks = nn.ModuleList(
            ResidualBlock(512) for _ in range(num_residual_blocks)
        )

        decoder_layers = [
            nn.Linear(512, 256),
            nn.GELU(),
            nn.BatchNorm1d(256),
            nn.Linear(256, 128),
            nn.GELU(),
            nn.BatchNorm1d(128),
            nn.Linear(128, 4),
        ]
        self.decoder = nn.Sequential(*decoder_layers)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Forward pass through the error predictor.

        Parameters
        ----------
        x : Tensor
            Combined input [xyY_norm, base_pred_norm] of shape (batch_size, 7).

        Returns
        -------
        Tensor
            Predicted error correction of shape (batch_size, 4).
        """
        hidden = self.encoder(x)

        for residual_block in self.residual_blocks:
            hidden = residual_block(hidden)

        return self.decoder(hidden)
128
+
129
+
130
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base ONNX model and its normalization parameters.

    The base model is the first stage of the two-stage architecture: it maps
    xyY coordinates to an initial Munsell specification prediction that the
    error predictor subsequently refines.

    Parameters
    ----------
    model_path : Path
        Path to the ONNX model file.
    params_path : Path
        Path to the .npz file containing input and output normalization
        parameters.

    Returns
    -------
    session : ort.InferenceSession
        ONNX Runtime inference session for the base model.
    input_params : dict
        Input normalization ranges (x_range, y_range, Y_range).
    output_params : dict
        Output normalization ranges (hue_range, value_range, chroma_range,
        code_range).
    """
    inference_session = ort.InferenceSession(str(model_path))

    # ``allow_pickle`` is required because the params were saved as dict
    # objects inside the .npz archive.
    archive = np.load(params_path, allow_pickle=True)
    input_params = archive["input_params"].item()
    output_params = archive["output_params"].item()

    return inference_session, input_params, output_params
159
+
160
+
161
@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=Path),
    help="Path to base model ONNX file",
)
@click.option(
    "--params",
    type=click.Path(exists=True, path_type=Path),
    help="Path to normalization params file",
)
@click.option(
    "--epochs",
    type=int,
    default=200,
    help="Number of training epochs",
)
@click.option(
    "--batch-size",
    type=int,
    default=1024,
    help="Batch size for training",
)
@click.option(
    "--lr",
    type=float,
    default=3e-4,
    help="Learning rate",
)
@click.option(
    "--patience",
    type=int,
    default=20,
    help="Patience for early stopping",
)
def main(
    base_model: Path | None,
    params: Path | None,
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
) -> None:
    """
    Train error predictor with advanced MLP architecture.

    Parameters
    ----------
    base_model : Path or None
        Path to the base model ONNX file. If None, uses the default
        ``xyY_to_munsell_specification.onnx`` in the models directory.
    params : Path or None
        Path to normalization parameters .npz file. If None, uses the
        default parameters file matching the default base model.
    epochs : int
        Number of training epochs.
    batch_size : int
        Batch size for training.
    lr : float
        Learning rate for the AdamW optimizer.
    patience : int
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained base model
    2. Generates base model predictions for training data
    3. Computes residual errors between predictions and targets
    4. Trains error predictor on these residuals
    5. Uses precision-focused loss function
    6. Learning rate scheduling with ReduceLROnPlateau
    7. Early stopping based on validation loss
    8. Exports model to ONNX format
    9. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Error Predictor: MLP + GELU + Precision Loss")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    # Fall back to the default exported base model when the CLI options are
    # omitted (previously a missing --base-model crashed in load_base_model
    # with ``None``, despite the docstring promising a default path).
    # NOTE(review): default filenames assumed to match the base training
    # script's export names — confirm against models/from_xyY contents.
    base_model_path = base_model or (
        model_directory / "xyY_to_munsell_specification.onnx"
    )
    params_path = params or (
        model_directory / "xyY_to_munsell_specification_normalization_params.npz"
    )
    cache_file = data_dir / "training_data.npz"

    # Extract base model name for error predictor naming
    base_model_name = base_model_path.stem

    # Load base model
    LOGGER.info("")
    LOGGER.info("Loading base model from %s...", base_model_path)
    base_session, input_params, output_params = load_base_model(
        base_model_path, params_path
    )

    # Load training data
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate base model predictions
    LOGGER.info("")
    LOGGER.info("Generating base model predictions...")
    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)

    # Base predictions (normalized)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors (in normalized space); these residuals are the
    # regression targets for the error predictor.
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Statistics
    LOGGER.info("")
    LOGGER.info("Base model error statistics (normalized space):")
    LOGGER.info("  Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info("  Std of error: %.6f", np.std(error_train))
    LOGGER.info("  Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize error predictor model with MLP architecture
    model = ErrorPredictorMLP(num_residual_blocks=3).to(device)
    LOGGER.info("")
    LOGGER.info("Error predictor architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup with precision-focused loss
    LOGGER.info("")
    LOGGER.info("Using precision-focused loss function:")
    LOGGER.info("  - MSE (weight: 1.0)")
    LOGGER.info("  - MAE (weight: 0.5)")
    LOGGER.info("  - Log penalty for small errors (weight: 0.3)")
    LOGGER.info("  - Huber loss with delta=0.01 (weight: 0.5)")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    model_name = f"{base_model_name}_error_predictor"
    run_name = setup_mlflow_experiment("from_xyY", model_name)

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Checkpoint destination; created up front so the post-training export
    # path is well-defined even for degenerate epoch counts.
    model_directory.mkdir(exist_ok=True)
    checkpoint_file = model_directory / f"{base_model_name}_error_predictor_best.pth"

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": model_name,
                "base_model": base_model_name,
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            # Update learning rate
            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting error predictor to ONNX...")
        model.eval()

        # Load best model
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input (xyY_norm + base_pred_norm = 7 inputs)
        dummy_input = torch.randn(1, 7).to(device)

        # Export
        onnx_file = model_directory / f"{base_model_name}_error_predictor.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("Error predictor ONNX model saved to: %s", onnx_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
452
+
453
+
454
if __name__ == "__main__":
    # Configure root logging for CLI usage; ``force=True`` replaces any
    # handlers previously installed by imported libraries (e.g. mlflow).
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_mlp_gamma.py ADDED
@@ -0,0 +1,297 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train ML model for xyY to Munsell conversion with gamma-corrected Y.
3
+
4
+ Experiment: Apply gamma 2.33 to Y before normalization to better align
5
+ with perceptual lightness (Munsell Value scale is perceptually uniform).
6
+ """
7
+
8
+ import logging
9
+ from typing import Any
10
+
11
+ import click
12
+ import mlflow
13
+ import mlflow.pytorch
14
+ import numpy as np
15
+ import torch
16
+ from numpy.typing import NDArray
17
+ from torch import optim
18
+ from torch.utils.data import DataLoader, TensorDataset
19
+
20
+ from learning_munsell import PROJECT_ROOT
21
+ from learning_munsell.models.networks import MLPToMunsell
22
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
23
+ from learning_munsell.utilities.data import (
24
+ MUNSELL_NORMALIZATION_PARAMS,
25
+ normalize_munsell,
26
+ )
27
+ from learning_munsell.utilities.losses import weighted_mse_loss
28
+ from learning_munsell.utilities.training import train_epoch, validate
29
+
30
+ LOGGER = logging.getLogger(__name__)
31
+
32
# Gamma value for Y transformation
GAMMA = 2.33


def normalize_inputs(
    X: NDArray, gamma: float = GAMMA
) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] range with gamma correction on Y.

    The chromaticity coordinates x and y are min-max scaled, while Y is
    scaled, clipped to [0, 1] and raised to ``1 / gamma``, spreading dark
    values and compressing light ones to better match the perceptually
    uniform Munsell Value scale.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3) where columns are [x, y, Y].
    gamma : float
        Gamma value to apply to Y component.

    Returns
    -------
    ndarray
        Normalized values with gamma-corrected Y, dtype float32.
    dict
        Normalization parameters including gamma value.
    """
    # Typical ranges for xyY
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    # Cast to float32 up front so the returned array matches the documented
    # dtype regardless of the input dtype (previously a float64 input
    # produced a float64 output despite the docstring's float32 promise).
    X = np.asarray(X, dtype=np.float32)

    X_norm = X.copy()
    X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Normalize Y first, then apply gamma
    Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    # Clip to avoid numerical issues with negative values
    Y_normalized = np.clip(Y_normalized, 0, 1)
    # Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light
    X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }

    return X_norm, params
80
+
81
+
82
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train MLP model with gamma-corrected Y input.

    Parameters
    ----------
    epochs : int
        Number of training epochs.
    batch_size : int
        Batch size for training.
    lr : float
        Learning rate for the Adam optimizer.
    patience : int
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Normalizes inputs with gamma correction (gamma=2.33) on Y
    3. Normalizes Munsell outputs to [0, 1] range
    4. Trains MLP with weighted MSE loss
    5. Uses early stopping based on validation loss
    6. Exports best model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The gamma correction on Y aligns with perceptual lightness. The gamma
    transformation spreads dark values and compresses light values, matching
    human lightness perception and the perceptually uniform Munsell Value scale.
    """
    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Gamma Experiment")
    LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
    LOGGER.info("=" * 80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        # Previously referenced a nonexistent "01_generate_training_data.py".
        LOGGER.error("Please run generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize data with gamma correction
    X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA)
    X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)

    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    LOGGER.info("")
    LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
    LOGGER.info(
        "  Y range after gamma: [%.4f, %.4f]",
        X_train_norm[:, 2].min(),
        X_train_norm[:, 2].max(),
    )

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Component weights: emphasize chroma (2.0), de-emphasize code (0.5).
    # Created on the training device so weighted_mse_loss does not mix CPU
    # and GPU tensors when CUDA is available (previously a CPU-only tensor,
    # which raises a device-mismatch error on CUDA).
    weights = torch.tensor([1.0, 1.0, 2.0, 0.5], device=device)

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Per-component weighted MSE between prediction and target."""
        return weighted_mse_loss(pred, target, weights)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", f"mlp_gamma_{GAMMA}")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Checkpoint destination; created up front so the post-training export
    # path is well-defined even for degenerate epoch counts.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(exist_ok=True)
    checkpoint_file = model_directory / "mlp_gamma_best.pth"

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
                "gamma": GAMMA,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "input_params": input_params,
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info("  → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "mlp_gamma.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY_gamma"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters (including gamma)
        params_file = model_directory / "mlp_gamma_normalization_params.npz"
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
292
+
293
+
294
if __name__ == "__main__":
    # Configure root logging for CLI usage; ``force=True`` replaces any
    # handlers previously installed by imported libraries (e.g. mlflow).
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_3stage_error_predictor.py ADDED
@@ -0,0 +1,411 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train second-stage error predictor for 3-stage model.
3
+
4
+ Architecture: Multi-Head + Multi-Error Predictor + Multi-Error Predictor
5
+ - Stage 1: Multi-Head base model (existing)
6
+ - Stage 2: First error predictor (existing)
7
+ - Stage 3: Second error predictor (this script) - learns residuals from stage 2
8
+
9
+ The second error predictor has the same architecture as the first but learns
10
+ the remaining errors after the first error correction is applied.
11
+ """
12
+
13
+ import logging
14
+ from pathlib import Path
15
+ from typing import Any
16
+
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import onnxruntime as ort
22
+ import torch
23
+ from numpy.typing import NDArray
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import (
29
+ ComponentErrorPredictor,
30
+ MultiHeadErrorPredictorToMunsell,
31
+ )
32
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
33
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
34
+ from learning_munsell.utilities.losses import precision_focused_loss
35
+ from learning_munsell.utilities.training import train_epoch, validate
36
+
37
+ LOGGER = logging.getLogger(__name__)
38
+
39
+
40
@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to Multi-Head base model ONNX file",
)
@click.option(
    "--first-error-predictor",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to first error predictor ONNX file",
)
@click.option(
    "--params",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to normalization params file",
)
@click.option(
    "--epochs",
    type=int,
    default=300,
    help="Number of training epochs (default: 300)",
)
@click.option(
    "--batch-size",
    type=int,
    default=2048,
    help="Batch size for training (default: 2048)",
)
@click.option(
    "--lr",
    type=float,
    default=3e-4,
    help="Learning rate (default: 3e-4)",
)
@click.option(
    "--patience",
    type=int,
    default=30,
    help="Early stopping patience (default: 30)",
)
def main(
    base_model: Path | None,
    first_error_predictor: Path | None,
    params: Path | None,
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
) -> None:
    """
    Train the second-stage error predictor for the 3-stage model.

    This script trains the third stage of a 3-stage model:
    - Stage 1: Multi-Head base model (pre-trained)
    - Stage 2: First error predictor (pre-trained)
    - Stage 3: Second error predictor (trained by this script)

    The second error predictor learns the residual errors remaining after
    the first error correction is applied, further refining the predictions.

    Parameters
    ----------
    base_model : Path, optional
        Path to the Multi-Head base model ONNX file.
        Default: models/from_xyY/multi_head_large.onnx
    first_error_predictor : Path, optional
        Path to the first error predictor ONNX file.
        Default: models/from_xyY/multi_head_multi_error_predictor_large.onnx
    params : Path, optional
        Path to the normalization parameters file.
        Default: models/from_xyY/multi_head_large_normalization_params.npz
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Training batch size.
    lr : float, optional
        Learning rate for the AdamW optimizer.
    patience : int, optional
        Early stopping patience in epochs.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained Stage 1 and Stage 2 models
    2. Generates Stage 2 predictions (base + first error correction)
    3. Computes remaining residual errors
    4. Trains Stage 3 error predictor on these residuals
    5. Exports the model to ONNX format
    6. Logs metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Second Error Predictor: 3-Stage Model Training")
    LOGGER.info("Multi-Head + Multi-Error Predictor + Multi-Error Predictor")
    LOGGER.info("=" * 80)

    # NOTE(review): MPS takes precedence over CUDA when both report as
    # available — confirm this ordering is intentional (Apple Silicon bias).
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    LOGGER.info("Using device: %s", device)

    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    # Resolve default artifact locations when not provided on the CLI.
    if base_model is None:
        base_model = model_directory / "multi_head_large.onnx"
    if first_error_predictor is None:
        first_error_predictor = model_directory / "multi_head_multi_error_predictor_large.onnx"
    if params is None:
        params = model_directory / "multi_head_large_normalization_params.npz"

    cache_file = data_dir / "training_data_large.npz"

    # Bail out early (with an error log, not an exception) when any required
    # pre-trained artifact or the training data cache is missing.
    if not cache_file.exists():
        LOGGER.error("Error: Large training data not found at %s", cache_file)
        return

    if not base_model.exists():
        LOGGER.error("Error: Base model not found at %s", base_model)
        return

    if not first_error_predictor.exists():
        LOGGER.error("Error: First error predictor not found at %s", first_error_predictor)
        return

    # Load models: the two earlier stages are frozen and run via ONNX Runtime.
    LOGGER.info("")
    LOGGER.info("Loading Stage 1: Multi-Head base model from %s...", base_model)
    base_session = ort.InferenceSession(str(base_model))

    LOGGER.info("Loading Stage 2: First error predictor from %s...", first_error_predictor)
    error_predictor_session = ort.InferenceSession(str(first_error_predictor))

    # Load normalization params (stored as pickled dicts inside the .npz,
    # hence allow_pickle=True and .item() to unwrap the 0-d object arrays).
    params_data = np.load(params, allow_pickle=True)
    input_params = params_data["input_params"].item()
    output_params = params_data["output_params"].item()

    # Load training data
    LOGGER.info("Loading large training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate stage 2 predictions (base + first error correction)
    LOGGER.info("")
    LOGGER.info("Computing Stage 2 predictions (base + first error correction)...")

    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)
    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Chunked ONNX inference to bound memory on the large dataset.
    inference_batch_size = 50000

    # Stage 1: Base model predictions
    LOGGER.info(" Stage 1: Base model predictions (training set)...")
    base_pred_train = []
    for i in range(0, len(X_train_norm), inference_batch_size):
        batch = X_train_norm[i : i + inference_batch_size]
        pred = base_session.run(None, {"xyY": batch})[0]
        base_pred_train.append(pred)
    base_pred_train = np.concatenate(base_pred_train, axis=0)

    LOGGER.info(" Stage 1: Base model predictions (validation set)...")
    base_pred_val = []
    for i in range(0, len(X_val_norm), inference_batch_size):
        batch = X_val_norm[i : i + inference_batch_size]
        pred = base_session.run(None, {"xyY": batch})[0]
        base_pred_val.append(pred)
    base_pred_val = np.concatenate(base_pred_val, axis=0)

    # Stage 2: First error predictor corrections.  Its input is the
    # concatenation [normalized xyY, Stage-1 prediction].
    LOGGER.info(" Stage 2: First error predictor corrections (training set)...")
    combined_train = np.concatenate([X_train_norm, base_pred_train], axis=1).astype(np.float32)
    error_correction_train = []
    for i in range(0, len(combined_train), inference_batch_size):
        batch = combined_train[i : i + inference_batch_size]
        correction = error_predictor_session.run(None, {"combined_input": batch})[0]
        error_correction_train.append(correction)
    error_correction_train = np.concatenate(error_correction_train, axis=0)

    LOGGER.info(" Stage 2: First error predictor corrections (validation set)...")
    combined_val = np.concatenate([X_val_norm, base_pred_val], axis=1).astype(np.float32)
    error_correction_val = []
    for i in range(0, len(combined_val), inference_batch_size):
        batch = combined_val[i : i + inference_batch_size]
        correction = error_predictor_session.run(None, {"combined_input": batch})[0]
        error_correction_val.append(correction)
    error_correction_val = np.concatenate(error_correction_val, axis=0)

    # Stage 2 predictions (base + first error correction)
    stage2_pred_train = base_pred_train + error_correction_train
    stage2_pred_val = base_pred_val + error_correction_val

    # Compute remaining errors for stage 3 — these residuals are the
    # regression targets for the model trained below.
    error_train = y_train_norm - stage2_pred_train
    error_val = y_val_norm - stage2_pred_val

    # Statistics
    LOGGER.info("")
    LOGGER.info("Stage 2 prediction error statistics (normalized space):")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(error_train))
    LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Compare with stage 1 errors
    stage1_error_train = y_train_norm - base_pred_train
    LOGGER.info("")
    LOGGER.info("Stage 1 (base only) error statistics for comparison:")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(stage1_error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(stage1_error_train))

    # How much of the Stage-1 error Stage 2 already removed, in percent.
    error_reduction = (
        (np.mean(np.abs(stage1_error_train)) - np.mean(np.abs(error_train)))
        / np.mean(np.abs(stage1_error_train))
        * 100
    )
    LOGGER.info("")
    LOGGER.info("Stage 2 error reduction vs Stage 1: %.1f%%", error_reduction)

    # Create combined input for stage 3: [xyY_norm, stage2_pred_norm]
    X_train_combined = np.concatenate([X_train_norm, stage2_pred_train], axis=1)
    X_val_combined = np.concatenate([X_val_norm, stage2_pred_val], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize second error predictor (same architecture as first)
    model = MultiHeadErrorPredictorToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Stage 3: Second error predictor architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    criterion = precision_focused_loss

    run_name = setup_mlflow_experiment("from_xyY", "multi_head_3stage_error_predictor")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting Stage 3 training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_3stage_error_predictor",
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "weight_decay": 1e-5,
                "optimizer": "AdamW",
                "scheduler": "ReduceLROnPlateau",
                "criterion": "precision_focused_loss",
                "patience": patience,
                "total_params": total_params,
                "train_samples": len(X_train),
                "val_samples": len(X_val),
                "stage2_error_reduction_pct": error_reduction,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            # NOTE(review): epoch stats are also logged via log_training_epoch
            # above — this LOGGER.info may duplicate console output; confirm.
            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Checkpoint the best model seen so far; always taken on the
                # first epoch since best_val_loss starts at infinity.
                model_directory.mkdir(exist_ok=True)
                checkpoint_file = model_directory / "multi_head_3stage_error_predictor_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting Stage 3 error predictor to ONNX...")
        model.eval()

        # Reload the best checkpoint before export so the exported graph
        # carries the best weights, not the last epoch's.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # 7 inputs = 3 normalized xyY values + 4 Stage-2 Munsell predictions.
        dummy_input = torch.randn(1, 7).to(device)

        onnx_file = model_directory / "multi_head_3stage_error_predictor.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        LOGGER.info("Stage 3 error predictor ONNX model saved to: %s", onnx_file)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("=" * 80)
406
+
407
+
408
+ if __name__ == "__main__":
409
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
410
+
411
+ main()
learning_munsell/training/from_xyY/train_multi_head_circular.py ADDED
@@ -0,0 +1,479 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head model with circular hue loss for xyY to Munsell conversion.
3
+
4
+ This version uses circular loss for the hue component (which wraps from 0-10)
5
+ to avoid penalizing predictions near the boundary.
6
+
7
+ Key Difference from Standard Training:
8
+ - Uses munsell_component_loss() which applies circular MSE for hue
9
+ - and regular MSE for value/chroma/code components
10
+ """
11
+
12
+ from __future__ import annotations
13
+
14
+ import copy
15
+ import logging
16
+
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import torch
22
+ from torch import nn, optim
23
+ from torch.utils.data import DataLoader, TensorDataset
24
+
25
+ from learning_munsell import PROJECT_ROOT
26
+ from learning_munsell.utilities.common import setup_mlflow_experiment
27
+ from learning_munsell.utilities.data import (
28
+ MUNSELL_NORMALIZATION_PARAMS,
29
+ normalize_munsell,
30
+ )
31
+ from learning_munsell.training.from_xyY.hyperparameter_search_multi_head import (
32
+ MultiHeadParametric,
33
+ )
34
+
35
+ LOGGER = logging.getLogger(__name__)
36
+
37
+
38
def circular_mse_loss(
    pred_hue: torch.Tensor, target_hue: torch.Tensor, hue_range: float = 1.0
) -> torch.Tensor:
    """
    Mean squared error on a circle for the hue component.

    Hue wraps around at ``hue_range``, so the effective distance between
    two hues is the shorter of the direct gap and the wrap-around gap.

    Parameters
    ----------
    pred_hue : Tensor
        Predicted hue values (normalized 0-1).
    target_hue : Tensor
        Target hue values (normalized 0-1).
    hue_range : float
        Range of hue values (1.0 for normalized).

    Returns
    -------
    Tensor
        Circular MSE loss (scalar).
    """
    direct_gap = (pred_hue - target_hue).abs()
    # Going the other way around the circle costs hue_range - direct_gap;
    # the true angular separation is whichever is shorter.
    shortest_gap = torch.minimum(direct_gap, hue_range - direct_gap)
    return (shortest_gap * shortest_gap).mean()


def munsell_component_loss(
    pred: torch.Tensor, target: torch.Tensor, hue_range: float = 1.0
) -> torch.Tensor:
    """
    Component-wise loss for Munsell predictions.

    Hue (column 0) is scored with circular MSE, while value, chroma and
    code (columns 1-3) use plain MSE; the two terms are summed.

    Parameters
    ----------
    pred : Tensor
        Predictions [hue, value, chroma, code] (shape: [batch, 4]).
    target : Tensor
        Ground truth [hue, value, chroma, code] (shape: [batch, 4]).
    hue_range : float
        Range of normalized hue values (default 1.0).

    Returns
    -------
    Tensor
        Combined loss (scalar).
    """
    wraparound_term = circular_mse_loss(pred[:, 0], target[:, 0], hue_range)
    linear_term = nn.functional.mse_loss(pred[:, 1:], target[:, 1:])
    return wraparound_term + linear_term
89
+
90
+
91
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=512, help="Batch size for training")
@click.option("--lr", default=0.000837, help="Learning rate")
@click.option("--patience", default=30, help="Early stopping patience")
def main(
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
    encoder_width: float = 0.75,
    head_width: float = 1.5,
    chroma_head_width: float = 1.5,
    dropout: float = 0.0,
    weight_decay: float = 0.000013,
) -> tuple[MultiHeadParametric, float]:
    """
    Train Multi-Head model with circular hue loss.

    This script uses circular loss for the hue component (which wraps from
    0-10) to avoid penalizing predictions near the boundary.

    Parameters
    ----------
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Training batch size.
    lr : float, optional
        Learning rate for AdamW optimizer.
    patience : int, optional
        Early stopping patience in epochs.
    encoder_width : float, optional
        Width multiplier for the shared encoder.
        NOTE(review): not exposed as a click option, so it is always the
        default when invoked from the CLI — only a programmatic caller can
        change it (same for the parameters below).
    head_width : float, optional
        Width multiplier for hue, value, and code heads.
    chroma_head_width : float, optional
        Width multiplier for chroma head (typically larger).
    dropout : float, optional
        Dropout rate for regularization.
    weight_decay : float, optional
        Weight decay for AdamW optimizer.

    Returns
    -------
    model : MultiHeadParametric
        Trained model with best validation loss weights.
    best_val_loss : float
        Best validation loss achieved during training.

    Notes
    -----
    The training pipeline:
    1. Loads training data from cache
    2. Normalizes outputs to [0, 1] range
    3. Trains with circular MSE for hue and regular MSE for other components
    4. Uses CosineAnnealingLR scheduler
    5. Early stopping based on validation loss
    6. Exports model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The circular loss experiment showed that while mathematically correct,
    the circular distance creates gradient discontinuities that harm
    optimization. This model is included for comparison purposes.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Training Multi-Head (Circular Hue Loss) for xyY to Munsell conversion")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Using Circular Loss for Hue Component")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Hyperparameters:")
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" encoder_width: %.2f", encoder_width)
    LOGGER.info(" head_width: %.2f", head_width)
    LOGGER.info(" chroma_head_width: %.2f", chroma_head_width)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info(" weight_decay: %.6f", weight_decay)
    LOGGER.info("")

    # NOTE(review): only MPS or CPU are considered here — CUDA is never
    # selected, unlike sibling training scripts; confirm this is intended.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load data from cache
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Training samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs already in [0, 1] range)
    # Use shared normalization parameters covering the full Munsell space for generalization
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to tensors
    X_train_t = torch.from_numpy(X_train).float()
    y_train_t = torch.from_numpy(y_train_norm).float()
    X_val_t = torch.from_numpy(X_val).float()
    y_val_t = torch.from_numpy(y_val_norm).float()

    train_loader = DataLoader(
        TensorDataset(X_train_t, y_train_t), batch_size=batch_size, shuffle=True
    )
    val_loader = DataLoader(
        TensorDataset(X_val_t, y_val_t), batch_size=batch_size, shuffle=False
    )

    # Create model
    model = MultiHeadParametric(
        encoder_width=encoder_width,
        head_width=head_width,
        chroma_head_width=chroma_head_width,
        dropout=dropout,
    ).to(device)

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=weight_decay)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("")
    LOGGER.info("Model parameters: %s", f"{total_params:,}")

    # Per-submodule parameter counts, for the architecture breakdown log.
    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    hue_params = sum(p.numel() for p in model.hue_head.parameters())
    value_params = sum(p.numel() for p in model.value_head.parameters())
    chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
    code_params = sum(p.numel() for p in model.code_head.parameters())

    LOGGER.info(" - Shared encoder (%.2fx): %s", encoder_width, f"{encoder_params:,}")
    LOGGER.info(" - Hue head (%.2fx): %s", head_width, f"{hue_params:,}")
    LOGGER.info(" - Value head (%.2fx): %s", head_width, f"{value_params:,}")
    LOGGER.info(" - Chroma head (%.2fx): %s", chroma_head_width, f"{chroma_params:,}")
    LOGGER.info(" - Code head (%.2fx): %s", head_width, f"{code_params:,}")

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_circular")
    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    best_val_loss = float("inf")
    best_state = None
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training with circular hue loss...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_circular",
                "encoder_width": encoder_width,
                "head_width": head_width,
                "chroma_head_width": chroma_head_width,
                "dropout": dropout,
                "learning_rate": lr,
                "batch_size": batch_size,
                "weight_decay": weight_decay,
                "epochs": epochs,
                "patience": patience,
                "total_params": total_params,
                "loss_type": "circular_hue",
            }
        )

        for epoch in range(epochs):
            # Training
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                pred = model(X_batch)

                # Use circular loss for hue component
                loss = munsell_component_loss(pred, y_batch, hue_range=1.0)

                loss.backward()
                optimizer.step()
                # Weight by batch size so the epoch mean is per-sample.
                train_loss += loss.item() * len(X_batch)

            train_loss /= len(X_train_t)
            scheduler.step()

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                    pred = model(X_batch)
                    val_loss += munsell_component_loss(
                        pred, y_batch, hue_range=1.0
                    ).item() * len(X_batch)
            val_loss /= len(X_val_t)

            # Per-component MAE (denormalized for interpretability)
            with torch.no_grad():
                pred_val = model(X_val_t.to(device)).cpu()
                # Denormalize predictions and ground truth
                # NOTE(review): .numpy() shares memory with pred_val — each
                # column assignment below reads its own column before writing
                # it, so this is correct, but fragile if reordered.
                pred_denorm = pred_val.numpy()
                hue_min, hue_max = output_params["hue_range"]
                value_min, value_max = output_params["value_range"]
                chroma_min, chroma_max = output_params["chroma_range"]
                code_min, code_max = output_params["code_range"]

                pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min  # hue
                pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min  # value
                pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min  # chroma
                pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min  # code

                y_denorm = y_val_norm.copy()
                y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
                y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
                y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
                y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min

                mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)

            mlflow.log_metrics(
                {
                    "train_loss": train_loss,
                    "val_loss": val_loss,
                    "mae_hue": mae[0],
                    "mae_value": mae[1],
                    "mae_chroma": mae[2],
                    "mae_code": mae[3],
                },
                step=epoch,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep-copy so later training steps don't mutate the snapshot.
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
                LOGGER.info(
                    "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - "
                    "MAE: hue=%.4f, value=%.4f, chroma=%.4f, code=%.4f",
                    epoch + 1,
                    epochs,
                    train_loss,
                    val_loss,
                    mae[0],
                    mae[1],
                    mae[2],
                    mae[3],
                )
            else:
                patience_counter += 1
                # Only log non-improving epochs every 50 to reduce noise.
                if (epoch + 1) % 50 == 0:
                    LOGGER.info(
                        "Epoch %03d/%d - Train: %.6f, Val: %.6f",
                        epoch + 1,
                        epochs,
                        train_loss,
                        val_loss,
                    )

            if patience_counter >= patience:
                LOGGER.info("Early stopping at epoch %d", epoch + 1)
                break

        # Load best model
        model.load_state_dict(best_state)

        # Final evaluation
        model.eval()
        with torch.no_grad():
            pred_val = model(X_val_t.to(device)).cpu()
            pred_denorm = pred_val.numpy()
            hue_min, hue_max = output_params["hue_range"]
            value_min, value_max = output_params["value_range"]
            chroma_min, chroma_max = output_params["chroma_range"]
            code_min, code_max = output_params["code_range"]

            pred_denorm[:, 0] = pred_val[:, 0].numpy() * (hue_max - hue_min) + hue_min
            pred_denorm[:, 1] = pred_val[:, 1].numpy() * (value_max - value_min) + value_min
            pred_denorm[:, 2] = pred_val[:, 2].numpy() * (chroma_max - chroma_min) + chroma_min
            pred_denorm[:, 3] = pred_val[:, 3].numpy() * (code_max - code_min) + code_min

        y_denorm = y_val_norm.copy()
        y_denorm[:, 0] = y_val_norm[:, 0] * (hue_max - hue_min) + hue_min
        y_denorm[:, 1] = y_val_norm[:, 1] * (value_max - value_min) + value_min
        y_denorm[:, 2] = y_val_norm[:, 2] * (chroma_max - chroma_min) + chroma_min
        y_denorm[:, 3] = y_val_norm[:, 3] * (code_max - code_min) + code_min

        mae = np.mean(np.abs(pred_denorm - y_denorm), axis=0)

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_mae_hue": mae[0],
                "final_mae_value": mae[1],
                "final_mae_chroma": mae[2],
                "final_mae_code": mae[3],
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info("")
        LOGGER.info("Final Results:")
        LOGGER.info(" Best Val Loss: %.6f", best_val_loss)
        LOGGER.info(" MAE hue: %.6f", mae[0])
        LOGGER.info(" MAE value: %.6f", mae[1])
        LOGGER.info(" MAE chroma: %.6f", mae[2])
        LOGGER.info(" MAE code: %.6f", mae[3])

        # Save model
        models_dir = PROJECT_ROOT / "models" / "from_xyY"
        models_dir.mkdir(exist_ok=True)

        checkpoint_path = models_dir / "multi_head_circular.pth"
        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "output_params": output_params,
                "val_loss": best_val_loss,
                "mae": {
                    "hue": float(mae[0]),
                    "value": float(mae[1]),
                    "chroma": float(mae[2]),
                    "code": float(mae[3]),
                },
                "hyperparameters": {
                    "encoder_width": encoder_width,
                    "head_width": head_width,
                    "chroma_head_width": chroma_head_width,
                    "dropout": dropout,
                    "lr": lr,
                    "batch_size": batch_size,
                    "weight_decay": weight_decay,
                },
                "loss_type": "circular_hue",
            },
            checkpoint_path,
        )
        LOGGER.info("")
        LOGGER.info("Saved checkpoint: %s", checkpoint_path)

        # Export to ONNX (on CPU so the exported graph is device-agnostic)
        model.cpu().eval()
        dummy_input = torch.randn(1, 3)
        onnx_path = models_dir / "multi_head_circular.onnx"

        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            input_names=["xyY"],  # Match other models for comparison compatibility
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch"}, "munsell_spec": {0: "batch"}},
            opset_version=17,
        )
        LOGGER.info("Saved ONNX: %s", onnx_path)

        # Save normalization parameters
        params_path = models_dir / "multi_head_circular_normalization_params.npz"
        np.savez(
            params_path,
            output_params=output_params,
        )
        LOGGER.info("Saved normalization parameters: %s", params_path)

        # Log artifacts to MLflow
        mlflow.log_artifact(str(checkpoint_path))
        mlflow.log_artifact(str(onnx_path))
        mlflow.log_artifact(str(params_path))
        mlflow.pytorch.log_model(model, "model")
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)

    # Returned for programmatic callers; click discards this when run as CLI.
    return model, best_val_loss
474
+
475
+
476
+ if __name__ == "__main__":
477
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
478
+
479
+ main()
learning_munsell/training/from_xyY/train_multi_head_cross_attention_error_predictor.py ADDED
@@ -0,0 +1,640 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head + Cross-Attention Error Predictor for xyY to Munsell conversion.
3
+
4
+ This version uses cross-attention between component branches to learn
5
+ correlations between errors in different Munsell components.
6
+
7
+ Key Features:
8
+ - Shared context encoder
9
+ - Multi-head cross-attention between components
10
+ - Component-specific prediction heads
11
+ - Residual connections
12
+ """
13
+
14
+ from __future__ import annotations
15
+
16
+ import copy
17
+ import logging
18
+
19
+ import mlflow
20
+ import mlflow.pytorch
21
+ import numpy as np
22
+ import onnxruntime as ort
23
+ import torch
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
29
+ from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
+ # Note: This script has a custom CrossAttentionErrorPredictor architecture
34
+ # so we don't import ComponentErrorPredictor/MultiHeadErrorPredictor from shared modules.
35
+
36
+
37
class CustomMultiheadAttention(nn.Module):
    """
    ONNX-friendly multi-head self-attention built from primitive ops.

    ``nn.MultiheadAttention`` produces reshape patterns that break ONNX
    export with dynamic batch sizes, so this module assembles the same
    computation from plain linear projections, matmuls, and a softmax.

    Parameters
    ----------
    embed_dim : int
        Total model dimension; must be divisible by ``num_heads``.
    num_heads : int
        Number of parallel attention heads.
    dropout : float, optional
        Dropout probability applied to the attention weights.

    Attributes
    ----------
    embed_dim : int
        Total embedding dimension.
    num_heads : int
        Number of attention heads.
    head_dim : int
        Per-head dimension (``embed_dim // num_heads``).
    scale : float
        Attention-score scaling factor (``head_dim ** -0.5``).
    q_proj, k_proj, v_proj, out_proj : nn.Linear
        Query / key / value / output projections.
    dropout : nn.Dropout
        Dropout applied to the attention weights.
    """

    def __init__(
        self,
        embed_dim: int,
        num_heads: int,
        dropout: float = 0.0,
    ) -> None:
        """Create the Q/K/V/output projections and the weight dropout."""
        super().__init__()

        assert embed_dim % num_heads == 0, "embed_dim must be divisible by num_heads"

        self.embed_dim = embed_dim
        self.num_heads = num_heads
        self.head_dim = embed_dim // num_heads
        # Standard 1/sqrt(d_k) scaling for the dot-product scores.
        self.scale = self.head_dim**-0.5

        # Separate projections; creation order is kept stable so seeded
        # initialization and checkpoint state-dict keys stay consistent.
        self.q_proj = nn.Linear(embed_dim, embed_dim)
        self.k_proj = nn.Linear(embed_dim, embed_dim)
        self.v_proj = nn.Linear(embed_dim, embed_dim)
        self.out_proj = nn.Linear(embed_dim, embed_dim)

        self.dropout = nn.Dropout(dropout)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Apply self-attention to ``x``.

        Parameters
        ----------
        x : Tensor
            Input of shape [batch, seq_len, embed_dim].

        Returns
        -------
        Tensor
            Output of shape [batch, seq_len, embed_dim].
        """
        seq_len = x.shape[1]

        def split_heads(t: torch.Tensor) -> torch.Tensor:
            # [batch, seq, embed] -> [batch, heads, seq, head_dim].
            # ``-1`` in place of the batch size keeps the exported ONNX
            # graph valid for dynamic batch sizes.
            return t.reshape(-1, seq_len, self.num_heads, self.head_dim).permute(
                0, 2, 1, 3
            )

        queries = split_heads(self.q_proj(x))
        keys = split_heads(self.k_proj(x))
        values = split_heads(self.v_proj(x))

        # Scaled dot-product attention: [batch, heads, seq, seq] weights.
        scores = (queries @ keys.transpose(-2, -1)) * self.scale
        weights = self.dropout(scores.softmax(dim=-1))

        # Mix the values, merge the heads back, and project out.
        mixed = (weights @ values).permute(0, 2, 1, 3).contiguous()
        mixed = mixed.reshape(-1, seq_len, self.embed_dim)

        return self.out_proj(mixed)
150
+
151
+
152
class CrossAttentionErrorPredictor(nn.Module):
    """
    Residual-error predictor whose component branches exchange information.

    A shared encoder maps the 7-dimensional input (normalized xyY plus the
    base model's normalized Munsell prediction) to a context vector. Four
    branch encoders specialize that context, multi-head attention lets the
    branches attend to one another (capturing correlated errors between
    hue, value, chroma, and code), and per-branch decoders emit one scalar
    correction each.

    Parameters
    ----------
    input_dim : int, optional
        Input dimension (7 = xyY_norm + base_pred_norm).
    context_dim : int, optional
        Dimension of the shared context features.
    component_dim : int, optional
        Dimension of the component-specific features.
    n_components : int, optional
        Number of Munsell components (4).
    n_attention_heads : int, optional
        Number of attention heads.
    dropout : float, optional
        Dropout probability.

    Attributes
    ----------
    context_encoder : nn.Sequential
        Shared trunk: input_dim → 256 → context_dim.
    component_encoders : nn.ModuleList
        Per-component encoders: context_dim → component_dim (x4).
    cross_attention : CustomMultiheadAttention
        Attention over the stacked component features.
    attention_norm : nn.LayerNorm
        Post-attention layer normalization (applied after the residual add).
    component_decoders : nn.ModuleList
        Per-component decoders: component_dim → 128 → 1 (x4).
    """

    def __init__(
        self,
        input_dim: int = 7,
        context_dim: int = 512,
        component_dim: int = 256,
        n_components: int = 4,
        n_attention_heads: int = 4,
        dropout: float = 0.1,
    ) -> None:
        """Build the encoder, attention, and decoder sub-modules."""
        super().__init__()

        self.n_components = n_components
        self.component_dim = component_dim

        # Shared trunk: input -> 256 -> context_dim.
        self.context_encoder = nn.Sequential(
            nn.Linear(input_dim, 256),
            nn.GELU(),
            nn.LayerNorm(256),
            nn.Dropout(dropout),
            nn.Linear(256, context_dim),
            nn.GELU(),
            nn.LayerNorm(context_dim),
        )

        def _branch_encoder() -> nn.Sequential:
            # One specialization branch per Munsell component.
            return nn.Sequential(
                nn.Linear(context_dim, component_dim),
                nn.GELU(),
                nn.LayerNorm(component_dim),
            )

        self.component_encoders = nn.ModuleList(
            _branch_encoder() for _ in range(n_components)
        )

        # Attention over the length-4 "sequence" of component features
        # (custom implementation for clean ONNX export).
        self.cross_attention = CustomMultiheadAttention(
            embed_dim=component_dim,
            num_heads=n_attention_heads,
            dropout=dropout,
        )

        self.attention_norm = nn.LayerNorm(component_dim)

        def _branch_decoder() -> nn.Sequential:
            # One scalar correction head per component.
            return nn.Sequential(
                nn.Linear(component_dim, 128),
                nn.GELU(),
                nn.LayerNorm(128),
                nn.Dropout(dropout),
                nn.Linear(128, 1),
            )

        self.component_decoders = nn.ModuleList(
            _branch_decoder() for _ in range(n_components)
        )

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        """
        Predict per-component error corrections.

        Parameters
        ----------
        x : Tensor
            Input [xyY_norm (3) + base_pred_norm (4)] = 7 features.

        Returns
        -------
        Tensor
            Predicted errors [hue_err, value_err, chroma_err, code_err].
        """
        shared = self.context_encoder(x)  # [batch, context_dim]

        # Specialize the shared context per component, then stack into a
        # [batch, 4, component_dim] "sequence" for attention.
        branch_features = [encoder(shared) for encoder in self.component_encoders]
        stacked = torch.stack(branch_features, dim=1)

        # Residual cross-component attention followed by layer norm.
        stacked = self.attention_norm(stacked + self.cross_attention(stacked))

        # ``unbind`` keeps the batch dimension intact, which matters for the
        # ONNX export with a dynamic batch axis.
        hue_feat, value_feat, chroma_feat, code_feat = torch.unbind(stacked, dim=1)

        # Decode each component explicitly (unrolled for ONNX) and join
        # into a [batch, 4] prediction tensor.
        return torch.cat(
            [
                self.component_decoders[0](hue_feat),
                self.component_decoders[1](value_feat),
                self.component_decoders[2](chroma_feat),
                self.component_decoders[3](code_feat),
            ],
            dim=1,
        )
307
+
308
+
309
def train_cross_attention_error_predictor(
    epochs: int = 300,
    batch_size: int = 1024,
    lr: float = 0.0005,
    dropout: float = 0.1,
    context_dim: int = 512,
    component_dim: int = 256,
    n_attention_heads: int = 4,
) -> tuple[CrossAttentionErrorPredictor, float]:
    """
    Train cross-attention error predictor.

    This model uses cross-attention between component branches to learn
    correlations between errors in different Munsell components.

    Parameters
    ----------
    epochs : int, optional
        Maximum number of training epochs.
    batch_size : int, optional
        Training batch size.
    lr : float, optional
        Learning rate for AdamW optimizer.
    dropout : float, optional
        Dropout rate for regularization.
    context_dim : int, optional
        Dimension of shared context features.
    component_dim : int, optional
        Dimension of component-specific features.
    n_attention_heads : int, optional
        Number of attention heads for cross-attention.

    Returns
    -------
    model : CrossAttentionErrorPredictor
        Trained model with best validation loss weights.
    best_val_loss : float
        Best validation loss achieved during training.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained Multi-Head base model
    2. Generates base model predictions for training data
    3. Computes residual errors between predictions and targets
    4. Trains cross-attention error predictor on these residuals
    5. Uses CosineAnnealingLR scheduler
    6. Early stopping based on validation loss
    7. Exports model to ONNX format
    8. Logs metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Training Multi-Head + Cross-Attention Error Predictor")
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("Architecture:")
    LOGGER.info(" - Shared context encoder: 7 → 256 → %d", context_dim)
    LOGGER.info(" - Component encoders: %d → %d (x4)", context_dim, component_dim)
    LOGGER.info(" - Cross-attention: %d heads", n_attention_heads)
    LOGGER.info(" - Component decoders: %d → 128 → 1 (x4)", component_dim)
    LOGGER.info("")
    LOGGER.info("Hyperparameters:")
    LOGGER.info(" lr: %.6f", lr)
    LOGGER.info(" batch_size: %d", batch_size)
    LOGGER.info(" dropout: %.2f", dropout)
    LOGGER.info("")

    # NOTE(review): only MPS (Apple GPU) or CPU are considered here — no CUDA
    # branch; presumably developed on macOS. Confirm if CUDA support is wanted.
    device = torch.device("mps" if torch.backends.mps.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths for the frozen base model, its normalization params, and data.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"
    base_model_path = model_directory / "multi_head.onnx"
    params_path = model_directory / "multi_head_normalization_params.npz"
    cache_file = data_dir / "training_data.npz"

    # Load the frozen base model (inference only, via ONNX Runtime).
    LOGGER.info("")
    LOGGER.info("Loading Multi-Head base model from %s...", base_model_path)
    base_session = ort.InferenceSession(str(base_model_path))
    # allow_pickle is required because the params are stored as object dicts.
    params = np.load(params_path, allow_pickle=True)
    input_params = params["input_params"].item()
    output_params = params["output_params"].item()

    # Load training data
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate base model predictions in normalized space.
    # NOTE(review): assumes normalize_xyY returns float32 — ONNX Runtime
    # rejects float64 inputs for a float32 graph; confirm in utilities.data.
    LOGGER.info("")
    LOGGER.info("Generating Multi-Head base model predictions...")
    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Residual errors (targets for the error predictor), normalized space.
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    LOGGER.info("")
    LOGGER.info("Base model error statistics (normalized space):")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(error_train))
    LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Create combined input: [xyY_norm, base_prediction_norm] → 7 features.
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)
    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = CrossAttentionErrorPredictor(
        input_dim=7,
        context_dim=context_dim,
        component_dim=component_dim,
        n_attention_heads=n_attention_heads,
        dropout=dropout,
    ).to(device)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("")
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    context_params = sum(p.numel() for p in model.context_encoder.parameters())
    attention_params = sum(p.numel() for p in model.cross_attention.parameters())
    LOGGER.info(" - Context encoder: %s", f"{context_params:,}")
    LOGGER.info(" - Cross-attention: %s", f"{attention_params:,}")

    # Training setup: MSE on normalized residuals, AdamW, cosine LR decay.
    criterion = nn.MSELoss()
    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.CosineAnnealingLR(optimizer, T_max=epochs)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "cross_attention_error_predictor")
    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state.
    best_val_loss = float("inf")
    best_state = None
    patience = 30
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "cross_attention_error_predictor",
                "context_dim": context_dim,
                "component_dim": component_dim,
                "n_attention_heads": n_attention_heads,
                "dropout": dropout,
                "learning_rate": lr,
                "batch_size": batch_size,
                "epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            # Training
            model.train()
            train_loss = 0.0
            for X_batch, y_batch in train_loader:
                X_batch, y_batch = X_batch.to(device), y_batch.to(device)

                optimizer.zero_grad()
                pred = model(X_batch)
                loss = criterion(pred, y_batch)
                loss.backward()
                optimizer.step()
                # Accumulate the per-sample loss sum; divided by dataset
                # size below to get the epoch mean.
                train_loss += loss.item() * len(X_batch)

            train_loss /= len(X_train_t)
            scheduler.step()

            # Validation
            model.eval()
            val_loss = 0.0
            with torch.no_grad():
                for X_batch, y_batch in val_loader:
                    X_batch, y_batch = X_batch.to(device), y_batch.to(device)
                    pred = model(X_batch)
                    val_loss += criterion(pred, y_batch).item() * len(X_batch)
            val_loss /= len(X_val_t)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                # Deep-copy so later epochs don't mutate the snapshot.
                best_state = copy.deepcopy(model.state_dict())
                patience_counter = 0
                LOGGER.info(
                    "Epoch %03d/%d - Train: %.6f, Val: %.6f (best) - LR: %.6f",
                    epoch + 1,
                    epochs,
                    train_loss,
                    val_loss,
                    optimizer.param_groups[0]["lr"],
                )
            else:
                patience_counter += 1
                # Only log non-improving epochs every 50 epochs to keep the
                # output readable.
                if (epoch + 1) % 50 == 0:
                    LOGGER.info(
                        "Epoch %03d/%d - Train: %.6f, Val: %.6f",
                        epoch + 1,
                        epochs,
                        train_loss,
                        val_loss,
                    )

            if patience_counter >= patience:
                LOGGER.info("Early stopping at epoch %d", epoch + 1)
                break

        # Load best model. ``best_state`` is always set: the first epoch
        # necessarily improves on the initial ``inf`` best_val_loss.
        model.load_state_dict(best_state)

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        LOGGER.info("")
        LOGGER.info("Final Results:")
        LOGGER.info(" Best Val Loss: %.6f", best_val_loss)

        # Save model checkpoint.
        # NOTE(review): mkdir without parents=True — assumes PROJECT_ROOT/models
        # already exists (created by the base-model training); confirm.
        model_directory.mkdir(exist_ok=True)
        checkpoint_path = (
            model_directory / "multi_head_cross_attention_error_predictor.pth"
        )

        torch.save(
            {
                "model_state_dict": model.state_dict(),
                "val_loss": best_val_loss,
                "hyperparameters": {
                    "context_dim": context_dim,
                    "component_dim": component_dim,
                    "n_attention_heads": n_attention_heads,
                    "dropout": dropout,
                    "lr": lr,
                    "batch_size": batch_size,
                },
            },
            checkpoint_path,
        )
        LOGGER.info("")
        LOGGER.info("Saved checkpoint: %s", checkpoint_path)

        # Export to ONNX. Model is moved to CPU first so the exported graph
        # does not depend on the training device.
        LOGGER.info("")
        LOGGER.info("Exporting error predictor to ONNX...")
        model.eval()
        model.cpu()

        dummy_input = torch.randn(1, 7)
        onnx_path = model_directory / "multi_head_cross_attention_error_predictor.onnx"

        torch.onnx.export(
            model,
            dummy_input,
            onnx_path,
            export_params=True,
            opset_version=17,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        mlflow.log_artifact(str(checkpoint_path))
        mlflow.log_artifact(str(onnx_path))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_path)
        LOGGER.info("Artifacts logged to MLflow")

        LOGGER.info("=" * 80)

    # NOTE: model is on CPU at this point (moved for ONNX export above).
    return model, best_val_loss
627
+
628
+
629
if __name__ == "__main__":
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    # Default experiment configuration for a standalone run.
    config = {
        "epochs": 300,
        "batch_size": 1024,
        "lr": 0.0005,
        "dropout": 0.1,
        "context_dim": 512,
        "component_dim": 256,
        "n_attention_heads": 4,
    }
    train_cross_attention_error_predictor(**config)
640
+ )
learning_munsell/training/from_xyY/train_multi_head_gamma.py ADDED
@@ -0,0 +1,300 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model for xyY to Munsell conversion with gamma-corrected Y.
3
+
4
+ Experiment: Apply gamma 2.33 to Y before normalization to better align
5
+ with perceptual lightness (Munsell Value scale is perceptually uniform).
6
+
7
+ The multi-head architecture has separate heads for each Munsell component,
8
+ so gamma correction on Y should primarily benefit Value prediction without
9
+ negatively impacting Chroma prediction (unlike the single MLP).
10
+ """
11
+
12
+ import logging
13
+ from typing import Any
14
+
15
+ import click
16
+ import mlflow
17
+ import mlflow.pytorch
18
+ import numpy as np
19
+ import torch
20
+ from numpy.typing import NDArray
21
+ from torch import nn, optim
22
+ from torch.utils.data import DataLoader, TensorDataset
23
+
24
+ from learning_munsell import PROJECT_ROOT
25
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
26
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
27
+ from learning_munsell.utilities.data import MUNSELL_NORMALIZATION_PARAMS, normalize_munsell
28
+ from learning_munsell.utilities.losses import weighted_mse_loss
29
+ from learning_munsell.utilities.training import train_epoch, validate
30
+
31
+ LOGGER = logging.getLogger(__name__)
32
+
33
# Gamma value for Y transformation
GAMMA = 2.33


def normalize_inputs(
    X: NDArray, gamma: float = GAMMA
) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] range with gamma correction on Y.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3) where columns are [x, y, Y].
    gamma : float
        Gamma value to apply to Y component.

    Returns
    -------
    ndarray
        Normalized values with gamma-corrected Y, dtype float32.
    dict
        Normalization parameters including gamma value.
    """
    # xyY chromaticity and luminance ranges (all [0, 1])
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    # Work in float32 so the documented output dtype holds for any input.
    # A plain ``X.copy()`` would keep the input dtype: integer arrays would
    # silently truncate the gamma-corrected Y, and float64 inputs would
    # contradict the docstring (downstream code wraps this in FloatTensor).
    X = np.asarray(X, dtype=np.float32)
    X_norm = X.copy()
    X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Normalize Y first, then apply gamma
    Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    # Clip to avoid numerical issues with negative values
    Y_normalized = np.clip(Y_normalized, 0, 1)
    # Apply gamma: Y_gamma = Y^(1/gamma) - this spreads dark values, compresses light
    X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }

    return X_norm, params
81
+
82
+
83
+
84
+
85
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the multi-head model with gamma-corrected Y input.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Normalizes inputs with gamma correction (gamma=2.33) on Y
    3. Normalizes Munsell outputs to [0, 1] range
    4. Trains multi-head MLP with weighted MSE loss
    5. Uses early stopping based on validation loss
    6. Exports best model to ONNX format
    7. Logs metrics and artifacts to MLflow

    The gamma correction on Y aligns with perceptual lightness. The Munsell
    Value scale is perceptually uniform, so gamma correction should primarily
    benefit Value prediction without negatively impacting Chroma prediction.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Gamma Experiment")
    LOGGER.info("Gamma = %.2f applied to Y component", GAMMA)
    LOGGER.info("=" * 80)

    # NOTE(review): CUDA/CPU only — no MPS branch, unlike the sibling
    # cross-attention trainer; confirm the intended target hardware.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    # Graceful exit (not an exception) when the data-generation step was
    # skipped; this is a CLI entry point.
    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize data with gamma correction on Y (the experiment variable).
    X_train_norm, input_params = normalize_inputs(X_train, gamma=GAMMA)
    X_val_norm, _ = normalize_inputs(X_val, gamma=GAMMA)

    # Use shared normalization parameters for Munsell outputs
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    LOGGER.info("")
    LOGGER.info("Input normalization with gamma=%.2f:", GAMMA)
    LOGGER.info(" Y range after gamma: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max())

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model (default architecture from networks module).
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = weighted_mse_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", f"multi_head_gamma_{GAMMA}")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state.
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_gamma",
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
                "gamma": GAMMA,
            }
        )

        for epoch in range(epochs):
            # Shared helpers perform the forward/backward and eval passes.
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Checkpoint the best model to disk immediately rather than
                # keeping a state-dict copy in memory.
                # NOTE(review): mkdir without parents=True assumes
                # PROJECT_ROOT/models already exists; confirm.
                model_directory = PROJECT_ROOT / "models" / "from_xyY"
                model_directory.mkdir(exist_ok=True)
                checkpoint_file = model_directory / "multi_head_gamma_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "input_params": input_params,
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX. ``checkpoint_file`` and ``model_directory`` are
        # bound on the first epoch (best_val_loss starts at inf), so they
        # are guaranteed to exist here.
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Reload best weights (the in-memory model may be from a later,
        # worse epoch).
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Dummy input on the same device as the model for tracing.
        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "multi_head_gamma.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY_gamma"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY_gamma": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters (including gamma)
        params_file = model_directory / "multi_head_gamma_normalization_params.npz"
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("IMPORTANT: Input Y must be gamma-corrected with gamma=%.2f", GAMMA)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("=" * 80)
295
+
296
+
297
+ if __name__ == "__main__":
298
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
299
+
300
+ main()
learning_munsell/training/from_xyY/train_multi_head_gamma_sweep.py ADDED
@@ -0,0 +1,605 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML models with various gamma values to find optimal gamma.
3
+
4
+ Sweeps gamma from 1.0 to 3.0 in increments of 0.1 and evaluates each model
5
+ on real Munsell colours using Delta-E CIE2000.
6
+
7
+ Supports parallel execution with multiple runs per gamma for averaging.
8
+ """
9
+
10
+ import logging
11
+ from concurrent.futures import ProcessPoolExecutor, as_completed
12
+ from typing import Any
13
+
14
+ import numpy as np
15
+ import torch
16
+ from colour import XYZ_to_Lab, xyY_to_XYZ
17
+ from colour.difference import delta_E_CIE2000
18
+ from colour.notation.datasets.munsell import MUNSELL_COLOURS_REAL
19
+ from colour.notation.munsell import (
20
+ CCS_ILLUMINANT_MUNSELL,
21
+ munsell_specification_to_xyY,
22
+ )
23
+ from numpy.typing import NDArray
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
29
+ from learning_munsell.utilities.data import (
30
+ MUNSELL_NORMALIZATION_PARAMS,
31
+ normalize_munsell,
32
+ )
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
def normalize_inputs(X: NDArray, gamma: float) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] range with gamma correction on Y.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3) where columns are [x, y, Y].
    gamma : float
        Gamma value to apply to Y component. Must be strictly positive
        (the exponent used is ``1 / gamma``).

    Returns
    -------
    ndarray
        Normalized values with gamma-corrected Y, dtype float32.
    dict
        Normalization parameters including gamma value.
    """
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    # Cast up front so the documented float32 contract actually holds:
    # previously the input dtype leaked through unchanged.  Callers feed
    # the result to torch.FloatTensor, so float32 is the natural dtype.
    X_norm = np.asarray(X, dtype=np.float32).copy()
    X_norm[:, 0] = (X_norm[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    X_norm[:, 1] = (X_norm[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Clip Y into [0, 1] before applying the power curve so that the
    # fractional exponent never sees a negative base.
    Y_normalized = (X_norm[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    Y_normalized = np.clip(Y_normalized, 0, 1)
    X_norm[:, 2] = np.power(Y_normalized, 1.0 / gamma)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "gamma": gamma,
    }

    return X_norm, params
75
+
76
+
77
def denormalize_output(y_norm: NDArray, params: dict[str, Any]) -> NDArray:
    """
    Denormalize Munsell output from [0, 1] to original ranges.

    Parameters
    ----------
    y_norm : ndarray
        Normalized Munsell values in [0, 1] range; last axis is
        [hue, value, chroma, code].
    params : dict
        Normalization parameters containing ``hue_range``, ``value_range``,
        ``chroma_range`` and ``code_range`` entries as (low, high) pairs.

    Returns
    -------
    ndarray
        Denormalized Munsell values in original ranges.
    """
    range_keys = ("hue_range", "value_range", "chroma_range", "code_range")

    result = np.copy(y_norm)
    # Rescale each component independently: norm * (high - low) + low.
    for channel, key in enumerate(range_keys):
        low, high = params[key]
        result[..., channel] = y_norm[..., channel] * (high - low) + low

    return result
111
+
112
+
113
+ def weighted_mse_loss(
114
+ pred: torch.Tensor, target: torch.Tensor, weights: torch.Tensor = None
115
+ ) -> torch.Tensor:
116
+ """
117
+ Component-wise weighted MSE loss.
118
+
119
+ Parameters
120
+ ----------
121
+ pred : Tensor
122
+ Predicted Munsell values.
123
+ target : Tensor
124
+ Ground truth Munsell values.
125
+ weights : Tensor, optional
126
+ Component weights [w_hue, w_value, w_chroma, w_code].
127
+
128
+ Returns
129
+ -------
130
+ Tensor
131
+ Weighted mean squared error loss.
132
+ """
133
+ if weights is None:
134
+ weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=pred.device)
135
+ mse = (pred - target) ** 2
136
+ weighted_mse = mse * weights
137
+ return weighted_mse.mean()
138
+
139
+
140
def clamp_munsell_specification(spec: NDArray) -> NDArray:
    """
    Clamp Munsell specification to valid ranges.

    Parameters
    ----------
    spec : ndarray
        Munsell specification [hue, value, chroma, code].

    Returns
    -------
    ndarray
        Clamped Munsell specification within valid ranges:
        hue in [0.5, 10], value in [1, 9], chroma in [0, 50],
        code in [1, 10].
    """
    # Per-component (low, high) bounds in [hue, value, chroma, code] order.
    bounds = ((0.5, 10.0), (1.0, 9.0), (0.0, 50.0), (1.0, 10.0))

    clamped = np.copy(spec)
    for channel, (low, high) in enumerate(bounds):
        clamped[..., channel] = np.clip(spec[..., channel], low, high)

    return clamped
160
+
161
+
162
def compute_delta_e(pred: NDArray, reference_Lab: NDArray) -> list[float]:
    """
    Compute Delta-E CIE2000 for predicted Munsell specifications.

    Parameters
    ----------
    pred : ndarray
        Predicted Munsell specifications.
    reference_Lab : ndarray
        Reference CIELAB values for comparison.

    Returns
    -------
    list of float
        Delta-E CIE2000 values for valid predictions.

    Notes
    -----
    Predictions that cannot be converted to valid xyY are skipped.
    """
    delta_E_values = []
    for idx, raw_spec in enumerate(pred):
        try:
            clamped = clamp_munsell_specification(raw_spec)
            conversion_spec = clamped.copy()
            # The hue-family code must be integral for the conversion.
            conversion_spec[3] = round(clamped[3])
            predicted_xyY = munsell_specification_to_xyY(conversion_spec)
            predicted_Lab = XYZ_to_Lab(
                xyY_to_XYZ(predicted_xyY), CCS_ILLUMINANT_MUNSELL
            )
            delta_E_values.append(delta_E_CIE2000(reference_Lab[idx], predicted_Lab))
        except (RuntimeError, ValueError):
            # Out-of-gamut specification: skip this sample.
            continue
    return delta_E_values
196
+
197
+
198
def train_model(
    gamma: float,
    X_train: NDArray,
    y_train: NDArray,
    X_val: NDArray,
    y_val: NDArray,
    device: torch.device,
    num_epochs: int = 100,
    patience: int = 15,
) -> tuple[nn.Module, dict[str, Any], dict[str, Any], float]:
    """
    Train a multi-head model with specified gamma value.

    Parameters
    ----------
    gamma : float
        Gamma value for Y correction.
    X_train : ndarray
        Training inputs (xyY values).
    y_train : ndarray
        Training targets (Munsell specifications).
    X_val : ndarray
        Validation inputs.
    y_val : ndarray
        Validation targets.
    device : torch.device
        Device to run training on.
    num_epochs : int, optional
        Maximum number of training epochs. Default is 100.
    patience : int, optional
        Early stopping patience. Default is 15.

    Returns
    -------
    nn.Module
        Trained model with best validation loss.
    dict
        Input normalization parameters.
    dict
        Output normalization parameters.
    float
        Best validation loss achieved.
    """
    # Normalize data
    X_train_norm, input_params = normalize_inputs(X_train, gamma=gamma)
    X_val_norm, _ = normalize_inputs(X_val, gamma=gamma)

    # Use shared normalization parameters covering the full Munsell space for generalization
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to tensors
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=1024, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=1024, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    optimizer = optim.Adam(model.parameters(), lr=5e-4)
    criterion = weighted_mse_loss

    best_val_loss = float("inf")
    patience_counter = 0
    best_state = None

    for epoch in range(num_epochs):
        # Train
        model.train()
        for X_batch, y_batch in train_loader:
            X_batch = X_batch.to(device)
            y_batch = y_batch.to(device)

            outputs = model(X_batch)
            loss = criterion(outputs, y_batch)

            optimizer.zero_grad()
            loss.backward()
            optimizer.step()

        # Validate
        model.eval()
        total_val_loss = 0.0
        with torch.no_grad():
            for X_batch, y_batch in val_loader:
                X_batch = X_batch.to(device)
                y_batch = y_batch.to(device)
                outputs = model(X_batch)
                loss = criterion(outputs, y_batch)
                total_val_loss += loss.item()
        val_loss = total_val_loss / len(val_loader)

        if val_loss < best_val_loss:
            best_val_loss = val_loss
            patience_counter = 0
            # BUG FIX: `model.state_dict().copy()` is a *shallow* dict copy;
            # the parameter tensors inside keep being mutated in place by
            # optimizer.step(), so the "best" snapshot silently tracked the
            # latest weights.  Clone each tensor to freeze the snapshot.
            best_state = {
                key: value.detach().clone()
                for key, value in model.state_dict().items()
            }
        else:
            patience_counter += 1
            if patience_counter >= patience:
                break

    # Load best state
    if best_state is not None:
        model.load_state_dict(best_state)

    return model, input_params, output_params, best_val_loss
312
+
313
+
314
def evaluate_on_real_munsell(
    model: nn.Module,
    input_params: dict[str, Any],
    output_params: dict[str, Any],
    xyY_array: NDArray,
    reference_Lab: NDArray,
    device: torch.device,
) -> tuple[float, float]:
    """
    Evaluate model on real Munsell colors using Delta-E CIE2000.

    Parameters
    ----------
    model : nn.Module
        Trained model to evaluate.
    input_params : dict
        Input normalization parameters (must contain ``gamma``).
    output_params : dict
        Output normalization parameters.
    xyY_array : ndarray
        Real Munsell xyY values.
    reference_Lab : ndarray
        Reference CIELAB values for Delta-E computation.
    device : torch.device
        Device to run evaluation on.

    Returns
    -------
    float
        Mean Delta-E CIE2000.
    float
        Median Delta-E CIE2000.
    """
    model.eval()

    # Re-apply exactly the gamma-corrected normalization used at training time.
    normalized_inputs, _ = normalize_inputs(xyY_array, gamma=input_params["gamma"])
    input_tensor = torch.FloatTensor(normalized_inputs).to(device)

    with torch.no_grad():
        predictions_normalized = model(input_tensor).cpu().numpy()

    predictions = denormalize_output(predictions_normalized, output_params)
    delta_E_values = compute_delta_e(predictions, reference_Lab)

    return np.mean(delta_E_values), np.median(delta_E_values)
362
+
363
+
364
def run_single_trial(
    gamma: float,
    run_id: int,
    X_train: NDArray,
    y_train: NDArray,
    X_val: NDArray,
    y_val: NDArray,
    xyY_array: NDArray,
    reference_Lab: NDArray,
) -> dict[str, Any]:
    """
    Run a single training trial for a given gamma value.

    Parameters
    ----------
    gamma : float
        Gamma value for Y correction.
    run_id : int
        Run identifier for this trial.
    X_train : ndarray
        Training inputs.
    y_train : ndarray
        Training targets.
    X_val : ndarray
        Validation inputs.
    y_val : ndarray
        Validation targets.
    xyY_array : ndarray
        Real Munsell xyY values for evaluation.
    reference_Lab : ndarray
        Reference CIELAB values for Delta-E computation.

    Returns
    -------
    dict
        Results dictionary containing gamma, run_id, val_loss,
        mean_delta_e, and median_delta_e.

    Notes
    -----
    Uses CPU to avoid MPS multiprocessing issues.
    """
    # Worker processes stay on CPU: GPU/MPS contexts do not share cleanly
    # across forked processes.
    device = torch.device("cpu")

    trained_model, input_params, output_params, val_loss = train_model(
        gamma=gamma,
        X_train=X_train,
        y_train=y_train,
        X_val=X_val,
        y_val=y_val,
        device=device,
        num_epochs=100,
        patience=15,
    )

    mean_delta_e, median_delta_e = evaluate_on_real_munsell(
        trained_model, input_params, output_params, xyY_array, reference_Lab, device
    )

    return dict(
        gamma=gamma,
        run_id=run_id,
        val_loss=val_loss,
        mean_delta_e=mean_delta_e,
        median_delta_e=median_delta_e,
    )
431
+
432
+
433
def main() -> None:
    """
    Run gamma sweep experiment to find optimal gamma value.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Loads real Munsell colors for evaluation
    3. Sweeps gamma values from 1.0 to 3.0 in 0.1 increments
    4. Trains multiple models per gamma value for averaging
    5. Evaluates each model on real Munsell colors using Delta-E CIE2000
    6. Aggregates results and identifies best gamma value
    7. Saves results to NPZ file for analysis

    Uses parallel execution with ProcessPoolExecutor for efficiency.
    Each model is trained with early stopping and evaluated on validation set.
    """
    # Local import: argparse is only needed when run as a script.
    import argparse

    parser = argparse.ArgumentParser(description="Gamma sweep with averaging")
    parser.add_argument("--runs", type=int, default=3, help="Number of runs per gamma")
    parser.add_argument("--workers", type=int, default=4, help="Number of parallel workers")
    args = parser.parse_args()

    num_runs = args.runs
    num_workers = args.workers

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Gamma Sweep: Finding Optimal Gamma Value")
    LOGGER.info("Testing gamma values from 1.0 to 3.0 in increments of 0.1")
    LOGGER.info("Runs per gamma: %d, Parallel workers: %d", num_runs, num_workers)
    LOGGER.info("=" * 80)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        return

    LOGGER.info("\nLoading training data...")
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]
    LOGGER.info("Train samples: %d, Validation samples: %d", len(X_train), len(X_val))

    # Load real Munsell data for evaluation: precompute reference Lab values
    # once so each worker only has to evaluate its predictions.
    LOGGER.info("Loading real Munsell colours for evaluation...")
    xyY_values = []
    reference_Lab = []

    for munsell_spec_tuple, xyY in MUNSELL_COLOURS_REAL:
        try:
            # Dataset stores Y in [0, 100]; rescale to [0, 1].
            xyY_scaled = np.array([xyY[0], xyY[1], xyY[2] / 100.0])
            XYZ = xyY_to_XYZ(xyY_scaled)
            Lab = XYZ_to_Lab(XYZ, CCS_ILLUMINANT_MUNSELL)
            xyY_values.append(xyY_scaled)
            reference_Lab.append(Lab)
        except (RuntimeError, ValueError):
            # Skip entries that fail colour-space conversion.
            continue

    xyY_array = np.array(xyY_values)
    reference_Lab = np.array(reference_Lab)
    LOGGER.info("Loaded %d real Munsell colours", len(xyY_array))

    # Gamma values to test
    gamma_values = [round(1.0 + i * 0.1, 1) for i in range(21)]  # 1.0 to 3.0

    # Create all tasks: (gamma, run_id) pairs
    tasks = [(gamma, run_id) for gamma in gamma_values for run_id in range(num_runs)]
    total_tasks = len(tasks)

    LOGGER.info("\n" + "-" * 80)
    LOGGER.info("Starting gamma sweep: %d total tasks (%d gamma values x %d runs)",
                total_tasks, len(gamma_values), num_runs)
    LOGGER.info("-" * 80)

    all_results = []
    completed = 0

    # Fan out one process per (gamma, run) trial; the futures dict maps each
    # future back to its (gamma, run_id) pair for progress reporting.
    with ProcessPoolExecutor(max_workers=num_workers) as executor:
        futures = {
            executor.submit(
                run_single_trial, gamma, run_id,
                X_train, y_train, X_val, y_val, xyY_array, reference_Lab
            ): (gamma, run_id)
            for gamma, run_id in tasks
        }

        for future in as_completed(futures):
            gamma, run_id = futures[future]
            try:
                result = future.result()
                all_results.append(result)
                completed += 1
                LOGGER.info(
                    "[%3d/%3d] gamma=%.1f run=%d: mean_ΔE=%.4f, median_ΔE=%.4f",
                    completed, total_tasks, gamma, run_id,
                    result["mean_delta_e"], result["median_delta_e"]
                )
            except Exception as e:
                # A failed trial is logged and excluded from aggregation;
                # the sweep continues with the remaining trials.
                LOGGER.error("Task failed for gamma=%.1f run=%d: %s", gamma, run_id, e)
                completed += 1

    # Aggregate results by gamma (average across runs)
    aggregated = {}
    for r in all_results:
        gamma = r["gamma"]
        if gamma not in aggregated:
            aggregated[gamma] = {"val_losses": [], "means": [], "medians": []}
        aggregated[gamma]["val_losses"].append(r["val_loss"])
        aggregated[gamma]["means"].append(r["mean_delta_e"])
        aggregated[gamma]["medians"].append(r["median_delta_e"])

    results = []
    for gamma in sorted(aggregated.keys()):
        agg = aggregated[gamma]
        results.append({
            "gamma": gamma,
            "val_loss": np.mean(agg["val_losses"]),
            "val_loss_std": np.std(agg["val_losses"]),
            "mean_delta_e": np.mean(agg["means"]),
            "mean_delta_e_std": np.std(agg["means"]),
            "median_delta_e": np.mean(agg["medians"]),
            "median_delta_e_std": np.std(agg["medians"]),
            "num_runs": len(agg["means"]),
        })

    # Print results
    LOGGER.info("\n" + "=" * 80)
    LOGGER.info("GAMMA SWEEP RESULTS (averaged over %d runs)", num_runs)
    LOGGER.info("=" * 80)
    LOGGER.info("")
    LOGGER.info("%-8s %-14s %-14s %-14s", "Gamma", "Val Loss", "Mean ΔE", "Median ΔE")
    LOGGER.info("-" * 50)

    for r in results:
        LOGGER.info(
            "%-8.1f %-14s %-14s %-14s",
            r["gamma"],
            f"{r['val_loss']:.6f}±{r['val_loss_std']:.4f}",
            f"{r['mean_delta_e']:.4f}±{r['mean_delta_e_std']:.4f}",
            f"{r['median_delta_e']:.4f}±{r['median_delta_e_std']:.4f}",
        )

    # Find best by mean Delta-E
    best_by_mean = min(results, key=lambda x: x["mean_delta_e"])
    best_by_median = min(results, key=lambda x: x["median_delta_e"])

    LOGGER.info("")
    LOGGER.info("Best gamma by MEAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
                best_by_mean["gamma"], best_by_mean["mean_delta_e"],
                best_by_mean["mean_delta_e_std"])
    LOGGER.info("Best gamma by MEDIAN Delta-E: %.1f (ΔE = %.4f ± %.4f)",
                best_by_median["gamma"], best_by_median["median_delta_e"],
                best_by_median["median_delta_e_std"])

    # Save results
    # NOTE(review): the models/from_xyY directory is not created here —
    # confirm a training script has already created it before this runs.
    # The saved arrays hold Python dicts (object dtype), so loading them
    # back requires np.load(..., allow_pickle=True).
    results_file = PROJECT_ROOT / "models" / "from_xyY" / "gamma_sweep_results_averaged.npz"
    np.savez(results_file, results=results, all_results=all_results)
    LOGGER.info("\nResults saved to: %s", results_file)

    LOGGER.info("\n" + "=" * 80)
601
+
602
if __name__ == "__main__":
    # Bare "%(message)s" format keeps the result tables aligned; force=True
    # removes any handlers already attached by imported libraries.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_large.py ADDED
@@ -0,0 +1,246 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model on large dataset (2M samples) for xyY to Munsell conversion.
3
+
4
+ This script trains on the larger dataset for potentially improved accuracy.
5
+ Uses the same architecture as train_multi_head_mlp.py but with the large dataset.
6
+ """
7
+
8
+ import logging
9
+
10
+ import click
11
+ import mlflow
12
+ import mlflow.pytorch
13
+ import numpy as np
14
+ import torch
15
+ from numpy.typing import NDArray
16
+ from torch import nn, optim
17
+ from torch.utils.data import DataLoader, TensorDataset
18
+
19
+ from learning_munsell import PROJECT_ROOT
20
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
21
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
22
+ from learning_munsell.utilities.data import (
23
+ MUNSELL_NORMALIZATION_PARAMS,
24
+ XYY_NORMALIZATION_PARAMS,
25
+ normalize_munsell,
26
+ )
27
+ from learning_munsell.utilities.losses import weighted_mse_loss
28
+ from learning_munsell.utilities.training import train_epoch, validate
29
+
30
+ LOGGER = logging.getLogger(__name__)
31
+
32
+
33
@click.command()
@click.option("--epochs", default=300, help="Number of training epochs")
@click.option("--batch-size", default=2048, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=30, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train multi-head MLP on large dataset (2M samples) for xyY to Munsell.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from large cached .npz file
    2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1]
    3. Creates multi-head MLP with shared encoder and component-specific heads
    4. Trains with weighted MSE loss (emphasizing chroma)
    5. Uses Adam optimizer with ReduceLROnPlateau scheduler
    6. Applies early stopping based on validation loss (patience=30)
    7. Exports best model to ONNX format
    8. Logs metrics and artifacts to MLflow
    """

    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Model Training on Large Dataset (2M samples)")
    LOGGER.info("=" * 80)

    # Device priority: MPS (Apple) overrides CUDA/CPU when available.
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    if torch.backends.mps.is_available():
        device = torch.device("mps")
    LOGGER.info("Using device: %s", device)

    # Load large training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data_large.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Large training data not found at %s", cache_file)
        LOGGER.error("Please run generate_large_training_data.py first")
        return

    LOGGER.info("Loading large training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range)
    # Use shared normalization parameters covering the full Munsell space for generalization
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders (larger batch size for larger dataset)
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup: halve the LR when validation loss plateaus for 10 epochs.
    learning_rate = lr
    optimizer = optim.Adam(model.parameters(), lr=learning_rate)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=10
    )
    criterion = weighted_mse_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_large")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_large",
                "learning_rate": learning_rate,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
                "train_samples": len(X_train),
                "val_samples": len(X_val),
                "dataset": "large_2M",
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            scheduler.step(val_loss)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            if val_loss < best_val_loss:
                # Checkpoint every improvement so the best weights survive
                # an interrupted run.
                best_val_loss = val_loss
                patience_counter = 0

                model_directory = PROJECT_ROOT / "models" / "from_xyY"
                model_directory.mkdir(exist_ok=True)
                checkpoint_file = model_directory / "multi_head_large_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX: reload the best checkpoint first so the exported
        # graph carries the best weights, not the last epoch's.
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # weights_only=False: the checkpoint also stores output_params (a dict).
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        model_directory = PROJECT_ROOT / "models" / "from_xyY"
        onnx_file = model_directory / "multi_head_large.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Persist both normalization parameter sets next to the ONNX model so
        # inference code can reproduce the training-time scaling.
        params_file = model_directory / "multi_head_large_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
+
242
+
243
+ if __name__ == "__main__":
244
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
245
+
246
+ main()
learning_munsell/training/from_xyY/train_multi_head_mlp.py ADDED
@@ -0,0 +1,269 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model for xyY to Munsell conversion.
3
+
4
+ Architecture:
5
+ - Shared encoder: 3 inputs → 512-dim features
6
+ - 4 separate heads (one per component):
7
+ - Hue head (circular/angular)
8
+ - Value head (linear lightness)
9
+ - Chroma head (non-linear saturation - larger capacity)
10
+ - Code head (discrete categorical)
11
+
12
+ This architecture allows each component to learn specialized features
13
+ while sharing the general color space understanding.
14
+ """
15
+
16
+ import logging
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import torch
22
+ from numpy.typing import NDArray
23
+ from torch import nn, optim
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
28
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
29
+ from learning_munsell.utilities.data import (
30
+ MUNSELL_NORMALIZATION_PARAMS,
31
+ XYY_NORMALIZATION_PARAMS,
32
+ normalize_munsell,
33
+ )
34
+ from learning_munsell.utilities.losses import weighted_mse_loss
35
+ from learning_munsell.utilities.training import train_epoch, validate
36
+
37
+ LOGGER = logging.getLogger(__name__)
38
+
39
+
40
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train multi-head MLP for xyY to Munsell conversion.

    Parameters
    ----------
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Mini-batch size for both training and validation loaders.
    lr : float
        Adam learning rate.
    patience : int
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cached .npz file
    2. Normalizes xyY inputs (already [0,1]) and Munsell outputs to [0,1]
    3. Creates multi-head MLP with shared encoder and component-specific heads
    4. Trains with weighted MSE loss (emphasizing chroma)
    5. Uses Adam optimizer with no learning rate scheduling
    6. Applies early stopping based on validation loss (patience=20)
    7. Exports best model to ONNX format
    8. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head Model Training")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize outputs (xyY inputs are already in [0, 1] range).
    # Use shared normalization parameters covering the full Munsell space
    # for generalization.
    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Count parameters per component
    encoder_params = sum(p.numel() for p in model.encoder.parameters())
    hue_params = sum(p.numel() for p in model.hue_head.parameters())
    value_params = sum(p.numel() for p in model.value_head.parameters())
    chroma_params = sum(p.numel() for p in model.chroma_head.parameters())
    code_params = sum(p.numel() for p in model.code_head.parameters())

    LOGGER.info(" - Shared encoder: %s", f"{encoder_params:,}")
    LOGGER.info(" - Hue head: %s", f"{hue_params:,}")
    LOGGER.info(" - Value head: %s", f"{value_params:,}")
    LOGGER.info(" - Chroma head: %s (WIDER)", f"{chroma_params:,}")
    LOGGER.info(" - Code head: %s", f"{code_params:,}")

    # Training setup
    optimizer = optim.Adam(model.parameters(), lr=lr)

    # Per-component loss weights (hue, value, chroma, code) emphasizing chroma.
    # FIX: create the tensor on the compute device — a CPU-resident weight
    # tensor would fail against CUDA predictions inside weighted_mse_loss.
    weights = torch.tensor([1.0, 1.0, 3.0, 0.5], device=device)

    def criterion(pred: torch.Tensor, target: torch.Tensor) -> torch.Tensor:
        """Weighted MSE with the component weights above (PEP 8: def over lambda)."""
        return weighted_mse_loss(pred, target, weights)

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state
    best_val_loss = float("inf")
    patience_counter = 0
    final_epoch = 0  # tracked explicitly so logging is safe even when epochs == 0

    # Checkpoint location is fixed up-front so the export step below never
    # depends on names first bound inside the training loop.
    # parents=True also creates the intermediate "models" directory.
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log parameters
        mlflow.log_params(
            {
                "model": "multi_head",
                "learning_rate": lr,
                "batch_size": batch_size,
                "num_epochs": epochs,
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)
            final_epoch = epoch + 1

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            # Early stopping bookkeeping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model (normalization params travel with the weights)
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": final_epoch,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Reload best weights. weights_only=False is required because the
        # checkpoint also stores the normalization-parameter dict; this matches
        # the sibling large-model script and PyTorch >= 2.6, where the default
        # flipped to weights_only=True.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input (batch of one xyY triplet)
        dummy_input = torch.randn(1, 3).to(device)

        # Export
        onnx_file = model_directory / "multi_head.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["xyY"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters alongside model
        params_file = model_directory / "multi_head_normalization_params.npz"
        input_params = XYY_NORMALIZATION_PARAMS
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        # Log artifacts to MLflow
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("Artifacts logged to MLflow")

    LOGGER.info("=" * 80)
264
+
265
+
266
if __name__ == "__main__":
    # Message-only console output; force=True replaces any handlers that
    # imported libraries may already have installed on the root logger.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor.py ADDED
@@ -0,0 +1,378 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head error predictor for Multi-Head base model.
3
+
4
+ Architecture:
5
+ - 4 independent error correction branches (one per component)
6
+ - Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output
7
+ - Chroma branch: WIDER (1.5x capacity for hardest component)
8
+
9
+ Complete independence matches the Multi-Head base model philosophy.
10
+ """
11
+
12
+ import logging
13
+ from pathlib import Path
14
+ from typing import Any
15
+
16
+ import click
17
+ import mlflow
18
+ import mlflow.pytorch
19
+ import numpy as np
20
+ import onnxruntime as ort
21
+ import torch
22
+ from numpy.typing import NDArray
23
+ from torch import nn, optim
24
+ from torch.utils.data import DataLoader, TensorDataset
25
+
26
+ from learning_munsell import PROJECT_ROOT
27
+ from learning_munsell.models.networks import (
28
+ ComponentErrorPredictor,
29
+ MultiHeadErrorPredictorToMunsell,
30
+ )
31
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
32
+ from learning_munsell.utilities.data import normalize_munsell, normalize_xyY
33
+ from learning_munsell.utilities.losses import precision_focused_loss
34
+ from learning_munsell.utilities.training import train_epoch, validate
35
+
36
+ LOGGER = logging.getLogger(__name__)
37
+
38
+
39
+ def load_base_model(
40
+ model_path: Path, params_path: Path
41
+ ) -> tuple[ort.InferenceSession, dict, dict]:
42
+ """
43
+ Load Multi-Head base ONNX model and normalization parameters.
44
+
45
+ Parameters
46
+ ----------
47
+ model_path : Path
48
+ Path to Multi-Head base model ONNX file.
49
+ params_path : Path
50
+ Path to normalization parameters .npz file.
51
+
52
+ Returns
53
+ -------
54
+ session : ort.InferenceSession
55
+ ONNX Runtime inference session.
56
+ input_params : dict
57
+ Input normalization ranges.
58
+ output_params : dict
59
+ Output normalization ranges.
60
+ """
61
+ session = ort.InferenceSession(str(model_path))
62
+ params = np.load(params_path, allow_pickle=True)
63
+ return session, params["input_params"].item(), params["output_params"].item()
64
+
65
+
66
@click.command()
@click.option(
    "--base-model",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to Multi-Head base model ONNX file",
)
@click.option(
    "--params",
    type=click.Path(exists=True, path_type=Path),
    default=None,
    help="Path to normalization params file",
)
@click.option(
    "--epochs",
    type=int,
    default=200,
    help="Number of training epochs",
)
@click.option(
    "--batch-size",
    type=int,
    default=1024,
    help="Batch size for training",
)
@click.option(
    "--lr",
    type=float,
    default=3e-4,
    help="Learning rate",
)
@click.option(
    "--patience",
    type=int,
    default=20,
    help="Patience for early stopping",
)
def main(
    base_model: Path | None,
    params: Path | None,
    epochs: int,
    batch_size: int,
    lr: float,
    patience: int,
) -> None:
    """
    Train Multi-Head error predictor with 4 independent branches.

    Parameters
    ----------
    base_model : Path or None
        Path to Multi-Head base model ONNX file. Uses default if None.
    params : Path or None
        Path to normalization parameters. Uses default if None.
    epochs : int
        Maximum number of training epochs.
    batch_size : int
        Mini-batch size for both loaders.
    lr : float
        AdamW learning rate.
    patience : int
        Epochs without validation improvement before early stopping.

    Notes
    -----
    The training pipeline:
    1. Loads pre-trained base model
    2. Generates base model predictions for training data
    3. Computes residual errors between predictions and targets
    4. Trains error predictor on these residuals
    5. Uses precision-focused loss function
    6. Learning rate scheduling with ReduceLROnPlateau
    7. Early stopping based on validation loss
    8. Exports model to ONNX format
    9. Logs metrics and artifacts to MLflow
    """
    LOGGER.info("=" * 80)
    LOGGER.info("Multi-Head Error Predictor: 4 Independent Branches")
    LOGGER.info("=" * 80)

    # Set device
    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Paths
    model_directory = PROJECT_ROOT / "models" / "from_xyY"
    data_dir = PROJECT_ROOT / "data"

    # Use provided paths or defaults
    if base_model is None:
        base_model = model_directory / "multi_head.onnx"
    if params is None:
        params = model_directory / "multi_head_normalization_params.npz"

    cache_file = data_dir / "training_data.npz"

    # Load base model
    LOGGER.info("")
    LOGGER.info("Loading Multi-Head base model from %s...", base_model)
    base_session, input_params, output_params = load_base_model(base_model, params)

    # Load training data
    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)
    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Generate base model predictions
    LOGGER.info("")
    LOGGER.info("Generating Multi-Head base model predictions...")
    X_train_norm = normalize_xyY(X_train, input_params)
    y_train_norm = normalize_munsell(y_train, output_params)

    # Base predictions (normalized)
    base_pred_train_norm = base_session.run(None, {"xyY": X_train_norm})[0]

    X_val_norm = normalize_xyY(X_val, input_params)
    y_val_norm = normalize_munsell(y_val, output_params)
    base_pred_val_norm = base_session.run(None, {"xyY": X_val_norm})[0]

    # Compute errors (in normalized space) — the residuals the predictor learns
    error_train = y_train_norm - base_pred_train_norm
    error_val = y_val_norm - base_pred_val_norm

    # Statistics
    LOGGER.info("")
    LOGGER.info("Multi-Head base model error statistics (normalized space):")
    LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
    LOGGER.info(" Std of error: %.6f", np.std(error_train))
    LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))

    # Create combined input: [xyY_norm, base_prediction_norm]
    X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
    X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)

    # Convert to PyTorch tensors
    X_train_t = torch.FloatTensor(X_train_combined)
    error_train_t = torch.FloatTensor(error_train)
    X_val_t = torch.FloatTensor(X_val_combined)
    error_val_t = torch.FloatTensor(error_val)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, error_train_t)
    val_dataset = TensorDataset(X_val_t, error_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize Multi-Head error predictor
    model = MultiHeadErrorPredictorToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Multi-Head error predictor architecture:")
    LOGGER.info("%s", model)

    # Count parameters
    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Count parameters per branch
    hue_params = sum(p.numel() for p in model.hue_branch.parameters())
    value_params = sum(p.numel() for p in model.value_branch.parameters())
    chroma_params = sum(p.numel() for p in model.chroma_branch.parameters())
    code_params = sum(p.numel() for p in model.code_branch.parameters())

    LOGGER.info(" - Hue branch: %s", f"{hue_params:,}")
    LOGGER.info(" - Value branch: %s", f"{value_params:,}")
    LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}")
    LOGGER.info(" - Code branch: %s", f"{code_params:,}")

    # Training setup with precision-focused loss
    LOGGER.info("")
    LOGGER.info("Using precision-focused loss function:")
    LOGGER.info(" - MSE (weight: 1.0)")
    LOGGER.info(" - MAE (weight: 0.5)")
    LOGGER.info(" - Log penalty for small errors (weight: 0.3)")
    LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)")

    optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
    scheduler = optim.lr_scheduler.ReduceLROnPlateau(
        optimizer, mode="min", factor=0.5, patience=5
    )
    criterion = precision_focused_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_multi_error_predictor")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state
    best_val_loss = float("inf")
    patience_counter = 0
    final_epoch = 0  # tracked explicitly so logging is safe even when epochs == 0

    # Checkpoint location is fixed up-front so the export step below never
    # depends on names first bound inside the training loop.
    # parents=True also creates the intermediate "models" directory.
    model_directory.mkdir(parents=True, exist_ok=True)
    checkpoint_file = model_directory / "multi_head_multi_error_predictor_best.pth"

    LOGGER.info("")
    LOGGER.info("Starting training...")

    with mlflow.start_run(run_name=run_name):
        # Log hyperparameters
        mlflow.log_params(
            {
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "weight_decay": 1e-5,
                "optimizer": "AdamW",
                "scheduler": "ReduceLROnPlateau",
                "criterion": "precision_focused_loss",
                "patience": patience,
                "total_params": total_params,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)
            final_epoch = epoch + 1

            # Update learning rate
            scheduler.step(val_loss)

            # Log to MLflow
            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
                optimizer.param_groups[0]["lr"],
            )

            # Early stopping bookkeeping
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # Save best model
                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        # Log final metrics
        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": final_epoch,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting Multi-Head error predictor to ONNX...")
        model.eval()

        # Reload best weights. weights_only=False keeps loading consistent with
        # the sibling training scripts under PyTorch >= 2.6, where the default
        # flipped to weights_only=True.
        checkpoint = torch.load(checkpoint_file, weights_only=False)
        model.load_state_dict(checkpoint["model_state_dict"])

        # Create dummy input (xyY_norm + base_pred_norm = 7 inputs)
        dummy_input = torch.randn(1, 7).to(device)

        # Export
        onnx_file = model_directory / "multi_head_multi_error_predictor.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=15,
            input_names=["combined_input"],
            output_names=["error_correction"],
            dynamic_axes={
                "combined_input": {0: "batch_size"},
                "error_correction": {0: "batch_size"},
            },
        )

        LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file)

        # Log artifacts
        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))

        # Log model
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
373
+
374
+
375
if __name__ == "__main__":
    # Message-only console output; force=True replaces any handlers that
    # imported libraries may already have installed on the root logger.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()
learning_munsell/training/from_xyY/train_multi_head_multi_error_predictor_large.py ADDED
@@ -0,0 +1,409 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train Multi-Head error predictor on large dataset (2M samples).
3
+
4
+ Architecture:
5
+ - 4 independent error correction branches (one per component)
6
+ - Each branch: 7 inputs (xyY + base_pred) → encoder → decoder → 1 error output
7
+ - Chroma branch: WIDER (1.5x capacity for hardest component)
8
+
9
+ Uses the large dataset for improved model training.
10
+ """
11
+
12
+ import logging
13
+ from pathlib import Path
14
+ import click
15
+ import mlflow
16
+ import mlflow.pytorch
17
+ import numpy as np
18
+ import onnxruntime as ort
19
+ import torch
20
+ from numpy.typing import NDArray
21
+ from torch import nn, optim
22
+ from torch.utils.data import DataLoader, TensorDataset
23
+
24
+ from learning_munsell import PROJECT_ROOT
25
+ from learning_munsell.models.networks import (
26
+ ComponentErrorPredictor,
27
+ MultiHeadErrorPredictorToMunsell,
28
+ )
29
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
30
+ from learning_munsell.utilities.data import normalize_xyY, normalize_munsell
31
+ from learning_munsell.utilities.losses import precision_focused_loss
32
+ from learning_munsell.utilities.training import train_epoch, validate
33
+
34
+ LOGGER = logging.getLogger(__name__)
35
+
36
+
37
def load_base_model(
    model_path: Path, params_path: Path
) -> tuple[ort.InferenceSession, dict, dict]:
    """
    Load the base ONNX model and normalization parameters.

    Parameters
    ----------
    model_path : Path
        Path to the ONNX model file.
    params_path : Path
        Path to the normalization parameters file (.npz).

    Returns
    -------
    session : ort.InferenceSession
        ONNX Runtime inference session.
    input_params : dict
        Input normalization parameters.
    output_params : dict
        Output normalization parameters.
    """
    session = ort.InferenceSession(str(model_path))
    # Each parameter dict is stored in the .npz as a 0-d object array;
    # .item() recovers the plain dict (hence allow_pickle=True).
    stored = np.load(params_path, allow_pickle=True)
    normalization = tuple(
        stored[key].item() for key in ("input_params", "output_params")
    )
    return (session, *normalization)
62
+
63
+
64
+ @click.command()
65
+ @click.option(
66
+ "--base-model",
67
+ type=click.Path(exists=True, path_type=Path),
68
+ default=None,
69
+ help="Path to Multi-Head large base model ONNX file",
70
+ )
71
+ @click.option(
72
+ "--params",
73
+ type=click.Path(exists=True, path_type=Path),
74
+ default=None,
75
+ help="Path to normalization params file",
76
+ )
77
+ @click.option(
78
+ "--output-suffix",
79
+ type=str,
80
+ default="large",
81
+ help="Suffix for output filenames (default: 'large')",
82
+ )
83
+ @click.option(
84
+ "--epochs",
85
+ type=int,
86
+ default=300,
87
+ help="Number of training epochs (default: 300)",
88
+ )
89
+ @click.option(
90
+ "--batch-size",
91
+ type=int,
92
+ default=2048,
93
+ help="Batch size for training (default: 2048)",
94
+ )
95
+ @click.option(
96
+ "--lr",
97
+ type=float,
98
+ default=3e-4,
99
+ help="Learning rate (default: 3e-4)",
100
+ )
101
+ @click.option(
102
+ "--patience",
103
+ type=int,
104
+ default=30,
105
+ help="Early stopping patience (default: 30)",
106
+ )
107
+ def main(
108
+ base_model: Path | None,
109
+ params: Path | None,
110
+ output_suffix: str,
111
+ epochs: int,
112
+ batch_size: int,
113
+ lr: float,
114
+ patience: int,
115
+ ) -> None:
116
+ """
117
+ Train Multi-Head error predictor on large dataset.
118
+
119
+ This script trains an error predictor on top of the Multi-Head large
120
+ base model, using the 2M sample dataset for improved accuracy.
121
+
122
+ Parameters
123
+ ----------
124
+ base_model : Path, optional
125
+ Path to the Multi-Head large base model ONNX file.
126
+ Default: models/from_xyY/multi_head_large.onnx
127
+ params : Path, optional
128
+ Path to the normalization parameters file.
129
+ Default: models/from_xyY/multi_head_large_normalization_params.npz
130
+ output_suffix : str
131
+ Suffix for output filenames (default: 'large').
132
+
133
+ Notes
134
+ -----
135
+ The training pipeline:
136
+ 1. Loads pre-trained Multi-Head large base model
137
+ 2. Generates base model predictions for training data (in batches)
138
+ 3. Computes residual errors between predictions and targets
139
+ 4. Trains multi-head error predictor on these residuals
140
+ 5. Uses precision-focused loss function
141
+ 6. Learning rate scheduling with ReduceLROnPlateau
142
+ 7. Early stopping based on validation loss
143
+ 8. Exports model to ONNX format
144
+ 9. Logs metrics and artifacts to MLflow
145
+ """
146
+
147
+ LOGGER.info("=" * 80)
148
+ LOGGER.info("Multi-Head Error Predictor: Large Dataset (2M samples)")
149
+ LOGGER.info("=" * 80)
150
+
151
+ # Set device
152
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
153
+ if torch.backends.mps.is_available():
154
+ device = torch.device("mps")
155
+ LOGGER.info("Using device: %s", device)
156
+
157
+ # Paths
158
+ model_directory = PROJECT_ROOT / "models" / "from_xyY"
159
+ data_dir = PROJECT_ROOT / "data"
160
+
161
+ # Use provided paths or defaults for large model
162
+ if base_model is None:
163
+ base_model = model_directory / "multi_head_large.onnx"
164
+ if params is None:
165
+ params = model_directory / "multi_head_large_normalization_params.npz"
166
+
167
+ cache_file = data_dir / "training_data_large.npz"
168
+
169
+ if not cache_file.exists():
170
+ LOGGER.error("Error: Large training data not found at %s", cache_file)
171
+ LOGGER.error("Please run generate_large_training_data.py first")
172
+ return
173
+
174
+ if not base_model.exists():
175
+ LOGGER.error("Error: Multi-Head large base model not found at %s", base_model)
176
+ LOGGER.error("Please run train_multi_head_large.py first")
177
+ return
178
+
179
+ # Load base model
180
+ LOGGER.info("")
181
+ LOGGER.info("Loading Multi-Head large base model from %s...", base_model)
182
+ base_session, input_params, output_params = load_base_model(base_model, params)
183
+
184
+ # Load training data
185
+ LOGGER.info("Loading large training data from %s...", cache_file)
186
+ data = np.load(cache_file)
187
+ X_train = data["X_train"]
188
+ y_train = data["y_train"]
189
+ X_val = data["X_val"]
190
+ y_val = data["y_val"]
191
+
192
+ LOGGER.info("Train samples: %d", len(X_train))
193
+ LOGGER.info("Validation samples: %d", len(X_val))
194
+
195
+ # Generate base model predictions
196
+ LOGGER.info("")
197
+ LOGGER.info("Generating Multi-Head large base model predictions...")
198
+ X_train_norm = normalize_xyY(X_train, input_params)
199
+ y_train_norm = normalize_munsell(y_train, output_params)
200
+
201
+ # Base predictions (normalized) - process in batches for memory efficiency
202
+ LOGGER.info(" Processing training set predictions...")
203
+ inference_batch_size = 50000
204
+ base_pred_train_norm = []
205
+ for i in range(0, len(X_train_norm), inference_batch_size):
206
+ batch = X_train_norm[i : i + inference_batch_size]
207
+ pred = base_session.run(None, {"xyY": batch})[0]
208
+ base_pred_train_norm.append(pred)
209
+ base_pred_train_norm = np.concatenate(base_pred_train_norm, axis=0)
210
+
211
+ X_val_norm = normalize_xyY(X_val, input_params)
212
+ y_val_norm = normalize_munsell(y_val, output_params)
213
+
214
+ LOGGER.info(" Processing validation set predictions...")
215
+ base_pred_val_norm = []
216
+ for i in range(0, len(X_val_norm), inference_batch_size):
217
+ batch = X_val_norm[i : i + inference_batch_size]
218
+ pred = base_session.run(None, {"xyY": batch})[0]
219
+ base_pred_val_norm.append(pred)
220
+ base_pred_val_norm = np.concatenate(base_pred_val_norm, axis=0)
221
+
222
+ # Compute errors (in normalized space)
223
+ error_train = y_train_norm - base_pred_train_norm
224
+ error_val = y_val_norm - base_pred_val_norm
225
+
226
+ # Statistics
227
+ LOGGER.info("")
228
+ LOGGER.info("Multi-Head large base model error statistics (normalized space):")
229
+ LOGGER.info(" Mean absolute error: %.6f", np.mean(np.abs(error_train)))
230
+ LOGGER.info(" Std of error: %.6f", np.std(error_train))
231
+ LOGGER.info(" Max absolute error: %.6f", np.max(np.abs(error_train)))
232
+
233
+ # Create combined input: [xyY_norm, base_prediction_norm]
234
+ X_train_combined = np.concatenate([X_train_norm, base_pred_train_norm], axis=1)
235
+ X_val_combined = np.concatenate([X_val_norm, base_pred_val_norm], axis=1)
236
+
237
+ # Convert to PyTorch tensors
238
+ X_train_t = torch.FloatTensor(X_train_combined)
239
+ error_train_t = torch.FloatTensor(error_train)
240
+ X_val_t = torch.FloatTensor(X_val_combined)
241
+ error_val_t = torch.FloatTensor(error_val)
242
+
243
+ # Create data loaders (larger batch size for large dataset)
244
+ train_dataset = TensorDataset(X_train_t, error_train_t)
245
+ val_dataset = TensorDataset(X_val_t, error_val_t)
246
+
247
+ train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
248
+ val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)
249
+
250
+ # Initialize Multi-Head error predictor
251
+ model = MultiHeadErrorPredictorToMunsell().to(device)
252
+ LOGGER.info("")
253
+ LOGGER.info("Multi-Head error predictor architecture:")
254
+ LOGGER.info("%s", model)
255
+
256
+ # Count parameters
257
+ total_params = sum(p.numel() for p in model.parameters())
258
+ LOGGER.info("Total parameters: %s", f"{total_params:,}")
259
+
260
+ # Count parameters per branch
261
+ hue_params = sum(p.numel() for p in model.hue_branch.parameters())
262
+ value_params = sum(p.numel() for p in model.value_branch.parameters())
263
+ chroma_params = sum(p.numel() for p in model.chroma_branch.parameters())
264
+ code_params = sum(p.numel() for p in model.code_branch.parameters())
265
+
266
+ LOGGER.info(" - Hue branch: %s", f"{hue_params:,}")
267
+ LOGGER.info(" - Value branch: %s", f"{value_params:,}")
268
+ LOGGER.info(" - Chroma branch: %s (WIDER 1.5x)", f"{chroma_params:,}")
269
+ LOGGER.info(" - Code branch: %s", f"{code_params:,}")
270
+
271
+ # Training setup
272
+ LOGGER.info("")
273
+ LOGGER.info("Using precision-focused loss function:")
274
+ LOGGER.info(" - MSE (weight: 1.0)")
275
+ LOGGER.info(" - MAE (weight: 0.5)")
276
+ LOGGER.info(" - Log penalty for small errors (weight: 0.3)")
277
+ LOGGER.info(" - Huber loss with delta=0.01 (weight: 0.5)")
278
+
279
+ optimizer = optim.AdamW(model.parameters(), lr=lr, weight_decay=1e-5)
280
+ scheduler = optim.lr_scheduler.ReduceLROnPlateau(
281
+ optimizer, mode="min", factor=0.5, patience=10
282
+ )
283
+ criterion = precision_focused_loss
284
+
285
+ # MLflow setup
286
+ run_name = setup_mlflow_experiment(
287
+ "from_xyY", f"multi_head_multi_error_predictor_{output_suffix}"
288
+ )
289
+
290
+ LOGGER.info("")
291
+ LOGGER.info("MLflow run: %s", run_name)
292
+
293
+ # Training loop
294
+ best_val_loss = float("inf")
295
+ patience_counter = 0
296
+
297
+ LOGGER.info("")
298
+ LOGGER.info("Starting training...")
299
+
300
+ with mlflow.start_run(run_name=run_name):
301
+ mlflow.log_params(
302
+ {
303
+ "model": f"multi_head_multi_error_predictor_{output_suffix}",
304
+ "num_epochs": epochs,
305
+ "batch_size": batch_size,
306
+ "learning_rate": lr,
307
+ "weight_decay": 1e-5,
308
+ "optimizer": "AdamW",
309
+ "scheduler": "ReduceLROnPlateau",
310
+ "criterion": "precision_focused_loss",
311
+ "patience": patience,
312
+ "total_params": total_params,
313
+ "train_samples": len(X_train),
314
+ "val_samples": len(X_val),
315
+ "dataset": "large_2M",
316
+ }
317
+ )
318
+
319
+ for epoch in range(epochs):
320
+ train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
321
+ val_loss = validate(model, val_loader, criterion, device)
322
+
323
+ scheduler.step(val_loss)
324
+
325
+ log_training_epoch(
326
+ epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
327
+ )
328
+
329
+ LOGGER.info(
330
+ "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f, LR: %.6f",
331
+ epoch + 1,
332
+ epochs,
333
+ train_loss,
334
+ val_loss,
335
+ optimizer.param_groups[0]["lr"],
336
+ )
337
+
338
+ if val_loss < best_val_loss:
339
+ best_val_loss = val_loss
340
+ patience_counter = 0
341
+
342
+ model_directory.mkdir(exist_ok=True)
343
+ checkpoint_file = (
344
+ model_directory / f"multi_head_multi_error_predictor_{output_suffix}_best.pth"
345
+ )
346
+
347
+ torch.save(
348
+ {
349
+ "model_state_dict": model.state_dict(),
350
+ "epoch": epoch,
351
+ "val_loss": val_loss,
352
+ "output_params": output_params,
353
+ },
354
+ checkpoint_file,
355
+ )
356
+
357
+ LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
358
+ else:
359
+ patience_counter += 1
360
+ if patience_counter >= patience:
361
+ LOGGER.info("")
362
+ LOGGER.info("Early stopping after %d epochs", epoch + 1)
363
+ break
364
+
365
+ mlflow.log_metrics(
366
+ {
367
+ "best_val_loss": best_val_loss,
368
+ "final_epoch": epoch + 1,
369
+ }
370
+ )
371
+
372
+ # Export to ONNX
373
+ LOGGER.info("")
374
+ LOGGER.info("Exporting Multi-Head error predictor to ONNX...")
375
+ model.eval()
376
+
377
+ checkpoint = torch.load(checkpoint_file, weights_only=False)
378
+ model.load_state_dict(checkpoint["model_state_dict"])
379
+
380
+ dummy_input = torch.randn(1, 7).to(device)
381
+
382
+ onnx_file = model_directory / f"multi_head_multi_error_predictor_{output_suffix}.onnx"
383
+ torch.onnx.export(
384
+ model,
385
+ dummy_input,
386
+ onnx_file,
387
+ export_params=True,
388
+ opset_version=15,
389
+ input_names=["combined_input"],
390
+ output_names=["error_correction"],
391
+ dynamic_axes={
392
+ "combined_input": {0: "batch_size"},
393
+ "error_correction": {0: "batch_size"},
394
+ },
395
+ )
396
+
397
+ LOGGER.info("Multi-Head error predictor ONNX model saved to: %s", onnx_file)
398
+
399
+ mlflow.log_artifact(str(checkpoint_file))
400
+ mlflow.log_artifact(str(onnx_file))
401
+ mlflow.pytorch.log_model(model, "model")
402
+
403
+ LOGGER.info("=" * 80)
404
+
405
+
406
+ if __name__ == "__main__":
407
+ logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)
408
+
409
+ main()
learning_munsell/training/from_xyY/train_multi_head_st2084.py ADDED
@@ -0,0 +1,313 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ Train multi-head ML model for xyY to Munsell conversion with ST.2084 (PQ) encoded Y.
3
+
4
+ Experiment: Apply SMPTE ST.2084 (Perceptual Quantizer) encoding to Y before
5
+ normalization. ST.2084 is designed for perceptual uniformity across a wide
6
+ luminance range, potentially providing better alignment with Munsell Value
7
+ than simple gamma correction.
8
+
9
+ The multi-head architecture has separate heads for each Munsell component,
10
+ so PQ encoding on Y should primarily benefit Value prediction without
11
+ negatively impacting Chroma prediction.
12
+ """
13
+
14
+ import logging
15
+ from typing import Any
16
+
17
+ import click
18
+ import mlflow
19
+ import mlflow.pytorch
20
+ import numpy as np
21
+ import torch
22
+ from colour.models import eotf_inverse_ST2084
23
+ from numpy.typing import NDArray
24
+ from torch import nn, optim
25
+ from torch.utils.data import DataLoader, TensorDataset
26
+
27
+ from learning_munsell import PROJECT_ROOT
28
+ from learning_munsell.models.networks import MultiHeadMLPToMunsell
29
+ from learning_munsell.utilities.common import log_training_epoch, setup_mlflow_experiment
30
+ from learning_munsell.utilities.data import (
31
+ MUNSELL_NORMALIZATION_PARAMS,
32
+ normalize_munsell,
33
+ )
34
+ from learning_munsell.utilities.losses import weighted_mse_loss
35
+ from learning_munsell.utilities.training import train_epoch, validate
36
+
37
# Module-level logger named after the module, per stdlib logging convention.
LOGGER = logging.getLogger(__name__)

# Peak luminance for ST.2084 scaling
# Munsell Y is relative luminance [0, 1], we scale to cd/m² for ST.2084
# Using 100 cd/m² as reference white (typical SDR display)
L_P_REFERENCE = 100.0
43
+
44
+
45
def normalize_inputs(
    X: NDArray, L_p: float = L_P_REFERENCE
) -> tuple[NDArray, dict[str, Any]]:
    """
    Normalize xyY inputs to [0, 1] range with ST.2084 (PQ) encoding on Y.

    Parameters
    ----------
    X : ndarray
        xyY values of shape (n, 3) where columns are [x, y, Y].
    L_p : float
        Peak luminance in cd/m² for ST.2084 scaling.

    Returns
    -------
    ndarray
        Normalized values with ST.2084-encoded Y, dtype float32.
    dict
        Normalization parameters including L_p and encoding type.
    """

    # xyY chromaticity and luminance ranges (all [0, 1]); kept explicit so the
    # returned params dict fully describes the transform for inference time.
    x_range = (0.0, 1.0)
    y_range = (0.0, 1.0)
    Y_range = (0.0, 1.0)

    # Cast up-front so the returned array actually matches the float32 dtype
    # promised in the docstring (np.load typically yields float64); downstream
    # torch.FloatTensor conversion is unaffected.
    X_norm = X.astype(np.float32)
    X_norm[:, 0] = (X[:, 0] - x_range[0]) / (x_range[1] - x_range[0])
    X_norm[:, 1] = (X[:, 1] - y_range[0]) / (y_range[1] - y_range[0])

    # Normalize Y first, then apply ST.2084
    Y_normalized = (X[:, 2] - Y_range[0]) / (Y_range[1] - Y_range[0])
    # Clip to avoid numerical issues at the domain boundaries
    Y_normalized = np.clip(Y_normalized, 0, 1)
    # Scale to cd/m² and apply ST.2084 inverse EOTF (PQ encoding)
    # ST.2084 expects absolute luminance in cd/m²
    Y_cdm2 = Y_normalized * L_p
    # Passing L_p makes eotf_inverse_ST2084 treat L_p cd/m² as the signal
    # peak, so the encoded Y spans approximately [0, 1] for this dataset.
    X_norm[:, 2] = eotf_inverse_ST2084(Y_cdm2, L_p=L_p)

    params = {
        "x_range": x_range,
        "y_range": y_range,
        "Y_range": Y_range,
        "encoding": "ST2084",
        "L_p": L_p,
    }

    return X_norm, params
94
+
95
+
96
@click.command()
@click.option("--epochs", default=200, help="Number of training epochs")
@click.option("--batch-size", default=1024, help="Batch size for training")
@click.option("--lr", default=5e-4, help="Learning rate")
@click.option("--patience", default=20, help="Early stopping patience")
def main(epochs: int, batch_size: int, lr: float, patience: int) -> None:
    """
    Train the multi-head model with ST.2084 (PQ) encoded Y input.

    Notes
    -----
    The training pipeline:
    1. Loads training and validation data from cache
    2. Normalizes inputs with ST.2084 (PQ) encoding on Y
    3. Normalizes Munsell outputs to [0, 1] range
    4. Trains multi-head MLP with weighted MSE loss
    5. Uses early stopping based on validation loss
    6. Exports best model to ONNX format
    7. Logs metrics and artifacts to MLflow

    ST.2084 (Perceptual Quantizer) encoding is designed for perceptual
    uniformity across a wide luminance range, potentially providing better
    alignment with Munsell Value than simple gamma correction. The multi-head
    architecture isolates this effect to the Value head without negatively
    impacting Chroma prediction.
    """

    LOGGER.info("=" * 80)
    LOGGER.info("ML-Based xyY to Munsell Conversion: Multi-Head ST.2084 Experiment")
    LOGGER.info("ST.2084 (PQ) encoding applied to Y component (L_p=%.0f cd/m²)", L_P_REFERENCE)
    LOGGER.info("=" * 80)

    device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
    LOGGER.info("Using device: %s", device)

    # Load training data previously produced by the data-generation script.
    data_dir = PROJECT_ROOT / "data"
    cache_file = data_dir / "training_data.npz"

    if not cache_file.exists():
        LOGGER.error("Error: Training data not found at %s", cache_file)
        LOGGER.error("Please run 01_generate_training_data.py first")
        return

    LOGGER.info("Loading training data from %s...", cache_file)
    data = np.load(cache_file)

    X_train = data["X_train"]
    y_train = data["y_train"]
    X_val = data["X_val"]
    y_val = data["y_val"]

    LOGGER.info("Train samples: %d", len(X_train))
    LOGGER.info("Validation samples: %d", len(X_val))

    # Normalize data with ST.2084 encoding.  The returned params depend only
    # on constants and L_p, so the validation call's params are identical to
    # the training ones and discarded.
    X_train_norm, input_params = normalize_inputs(X_train, L_p=L_P_REFERENCE)
    X_val_norm, _ = normalize_inputs(X_val, L_p=L_P_REFERENCE)

    output_params = MUNSELL_NORMALIZATION_PARAMS
    y_train_norm = normalize_munsell(y_train, output_params)
    y_val_norm = normalize_munsell(y_val, output_params)

    LOGGER.info("")
    LOGGER.info("Input normalization with ST.2084 (L_p=%.0f):", L_P_REFERENCE)
    LOGGER.info(" Y range after ST.2084: [%.4f, %.4f]", X_train_norm[:, 2].min(), X_train_norm[:, 2].max())

    # Convert to PyTorch tensors (FloatTensor => float32 on CPU; batches are
    # moved to `device` inside train_epoch/validate).
    X_train_t = torch.FloatTensor(X_train_norm)
    y_train_t = torch.FloatTensor(y_train_norm)
    X_val_t = torch.FloatTensor(X_val_norm)
    y_val_t = torch.FloatTensor(y_val_norm)

    # Create data loaders
    train_dataset = TensorDataset(X_train_t, y_train_t)
    val_dataset = TensorDataset(X_val_t, y_val_t)

    train_loader = DataLoader(train_dataset, batch_size=batch_size, shuffle=True)
    val_loader = DataLoader(val_dataset, batch_size=batch_size, shuffle=False)

    # Initialize model
    model = MultiHeadMLPToMunsell().to(device)
    LOGGER.info("")
    LOGGER.info("Model architecture:")
    LOGGER.info("%s", model)

    total_params = sum(p.numel() for p in model.parameters())
    LOGGER.info("Total parameters: %s", f"{total_params:,}")

    # Training setup — plain Adam with no scheduler, matching the baseline
    # multi-head experiment so only the Y encoding differs.
    optimizer = optim.Adam(model.parameters(), lr=lr)
    criterion = weighted_mse_loss

    # MLflow setup
    run_name = setup_mlflow_experiment("from_xyY", "multi_head_st2084")

    LOGGER.info("")
    LOGGER.info("MLflow run: %s", run_name)

    # Training loop state for early stopping.
    best_val_loss = float("inf")
    patience_counter = 0

    LOGGER.info("")
    LOGGER.info("Starting training...")

    # NOTE(review): if --epochs=0 the loop below never runs and `epoch` /
    # `checkpoint_file` / `model_directory` are undefined further down.
    with mlflow.start_run(run_name=run_name):
        mlflow.log_params(
            {
                "model": "multi_head_st2084",
                "num_epochs": epochs,
                "batch_size": batch_size,
                "learning_rate": lr,
                "optimizer": "Adam",
                "criterion": "weighted_mse_loss",
                "patience": patience,
                "total_params": total_params,
                "encoding": "ST2084",
                "L_p": L_P_REFERENCE,
            }
        )

        for epoch in range(epochs):
            train_loss = train_epoch(model, train_loader, optimizer, criterion, device)
            val_loss = validate(model, val_loader, criterion, device)

            log_training_epoch(
                epoch, train_loss, val_loss, optimizer.param_groups[0]["lr"]
            )

            LOGGER.info(
                "Epoch %03d/%d - Train Loss: %.6f, Val Loss: %.6f",
                epoch + 1,
                epochs,
                train_loss,
                val_loss,
            )

            # Checkpoint only on improvement; reset the patience counter.
            if val_loss < best_val_loss:
                best_val_loss = val_loss
                patience_counter = 0

                # NOTE(review): mkdir without parents=True assumes
                # PROJECT_ROOT / "models" already exists — confirm.
                model_directory = PROJECT_ROOT / "models" / "from_xyY"
                model_directory.mkdir(exist_ok=True)
                checkpoint_file = model_directory / "multi_head_st2084_best.pth"

                torch.save(
                    {
                        "model_state_dict": model.state_dict(),
                        "input_params": input_params,
                        "output_params": output_params,
                        "epoch": epoch,
                        "val_loss": val_loss,
                    },
                    checkpoint_file,
                )

                LOGGER.info(" → Saved best model (val_loss: %.6f)", val_loss)
            else:
                patience_counter += 1
                if patience_counter >= patience:
                    LOGGER.info("")
                    LOGGER.info("Early stopping after %d epochs", epoch + 1)
                    break

        mlflow.log_metrics(
            {
                "best_val_loss": best_val_loss,
                "final_epoch": epoch + 1,
            }
        )

        # Export to ONNX
        LOGGER.info("")
        LOGGER.info("Exporting model to ONNX...")
        model.eval()

        # Reload the best checkpoint so the export reflects the best epoch,
        # not whatever state the loop ended in.
        checkpoint = torch.load(checkpoint_file)
        model.load_state_dict(checkpoint["model_state_dict"])

        dummy_input = torch.randn(1, 3).to(device)

        onnx_file = model_directory / "multi_head_st2084.onnx"
        torch.onnx.export(
            model,
            dummy_input,
            onnx_file,
            export_params=True,
            opset_version=17,
            input_names=["xyY_st2084"],
            output_names=["munsell_spec"],
            dynamic_axes={"xyY_st2084": {0: "batch_size"}, "munsell_spec": {0: "batch_size"}},
        )

        # Save normalization parameters (including ST.2084 info).
        # NOTE(review): dict values are stored by np.savez as object arrays;
        # loading them back requires allow_pickle=True — confirm readers do.
        params_file = model_directory / "multi_head_st2084_normalization_params.npz"
        np.savez(
            params_file,
            input_params=input_params,
            output_params=output_params,
        )

        LOGGER.info("ONNX model saved to: %s", onnx_file)
        LOGGER.info("Normalization parameters saved to: %s", params_file)
        LOGGER.info("IMPORTANT: Input Y must be ST.2084-encoded with L_p=%.0f", L_P_REFERENCE)

        mlflow.log_artifact(str(checkpoint_file))
        mlflow.log_artifact(str(onnx_file))
        mlflow.log_artifact(str(params_file))
        mlflow.pytorch.log_model(model, "model")

    LOGGER.info("=" * 80)
309
+
310
if __name__ == "__main__":
    # force=True removes any pre-existing root handlers (e.g. ones installed
    # by imported libraries) so the bare "%(message)s" format takes effect.
    logging.basicConfig(level=logging.INFO, format="%(message)s", force=True)

    main()