ArthurY commited on
Commit
c3d0544
·
1 Parent(s): bda03b1

update source

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitignore +70 -1
  2. Dockerfile +3 -0
  3. physics_mcp/mcp_output/requirements.txt +18 -11
  4. physics_mcp/source/.dockerignore +8 -0
  5. physics_mcp/source/.gitattributes +0 -0
  6. physics_mcp/source/.gitignore +176 -0
  7. physics_mcp/source/CHANGELOG.md +556 -0
  8. physics_mcp/source/CITATION.cff +7 -0
  9. physics_mcp/source/CONTRIBUTING.md +251 -0
  10. physics_mcp/source/FAQ.md +60 -0
  11. physics_mcp/source/LICENSE.txt +201 -0
  12. physics_mcp/source/README.md +472 -0
  13. physics_mcp/source/SECURITY.md +34 -0
  14. physics_mcp/source/__init__.py +4 -0
  15. physics_mcp/source/greptile.json +59 -0
  16. physics_mcp/source/physicsnemo/__init__.py +22 -0
  17. physics_mcp/source/physicsnemo/active_learning/README.md +66 -0
  18. physics_mcp/source/physicsnemo/active_learning/__init__.py +35 -0
  19. physics_mcp/source/physicsnemo/active_learning/_registry.py +332 -0
  20. physics_mcp/source/physicsnemo/active_learning/config.py +808 -0
  21. physics_mcp/source/physicsnemo/active_learning/driver.py +1449 -0
  22. physics_mcp/source/physicsnemo/active_learning/logger.py +330 -0
  23. physics_mcp/source/physicsnemo/active_learning/loop.py +534 -0
  24. physics_mcp/source/physicsnemo/active_learning/protocols.py +1394 -0
  25. physics_mcp/source/physicsnemo/constants.py +48 -0
  26. physics_mcp/source/physicsnemo/datapipes/__init__.py +15 -0
  27. physics_mcp/source/physicsnemo/datapipes/benchmarks/__init__.py +15 -0
  28. physics_mcp/source/physicsnemo/datapipes/benchmarks/darcy.py +322 -0
  29. physics_mcp/source/physicsnemo/datapipes/benchmarks/kelvin_helmholtz.py +436 -0
  30. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/__init__.py +15 -0
  31. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_difference.py +139 -0
  32. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_volume.py +759 -0
  33. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/indexing.py +182 -0
  34. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/initialization.py +77 -0
  35. physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/utils.py +141 -0
  36. physics_mcp/source/physicsnemo/datapipes/cae/__init__.py +18 -0
  37. physics_mcp/source/physicsnemo/datapipes/cae/cae_dataset.py +1275 -0
  38. physics_mcp/source/physicsnemo/datapipes/cae/domino_datapipe.py +1334 -0
  39. physics_mcp/source/physicsnemo/datapipes/cae/mesh_datapipe.py +490 -0
  40. physics_mcp/source/physicsnemo/datapipes/cae/readers.py +191 -0
  41. physics_mcp/source/physicsnemo/datapipes/climate/__init__.py +19 -0
  42. physics_mcp/source/physicsnemo/datapipes/climate/climate.py +813 -0
  43. physics_mcp/source/physicsnemo/datapipes/climate/era5_hdf5.py +622 -0
  44. physics_mcp/source/physicsnemo/datapipes/climate/era5_netcdf.py +15 -0
  45. physics_mcp/source/physicsnemo/datapipes/climate/synthetic.py +182 -0
  46. physics_mcp/source/physicsnemo/datapipes/climate/utils/__init__.py +15 -0
  47. physics_mcp/source/physicsnemo/datapipes/climate/utils/invariant.py +139 -0
  48. physics_mcp/source/physicsnemo/datapipes/climate/utils/zenith_angle.py +208 -0
  49. physics_mcp/source/physicsnemo/datapipes/datapipe.py +60 -0
  50. physics_mcp/source/physicsnemo/datapipes/gnn/__init__.py +15 -0
.gitignore CHANGED
@@ -1 +1,70 @@
1
- *.DS_Store
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ *.DS_Store
2
+
3
+ # ===== 源代码中不必要的目录(来自NVIDIA原项目) =====
4
+ # 文档 (102MB)
5
+ physics_mcp/source/docs/
6
+
7
+ # 测试 (28MB)
8
+ physics_mcp/source/test/
9
+
10
+ # 示例 (17MB)
11
+ physics_mcp/source/examples/
12
+
13
+ # ===== Git和CI/CD配置 =====
14
+ physics_mcp/source/.github/
15
+ physics_mcp/source/.gitlab/
16
+ physics_mcp/source/.gitlab-ci.yml
17
+ physics_mcp/source/.pre-commit-config.yaml
18
+ physics_mcp/source/.markdownlint.yaml
19
+
20
+ # ===== 项目配置文件 =====
21
+ physics_mcp/source/Dockerfile
22
+ physics_mcp/source/Makefile
23
+ physics_mcp/source/.gitmodules
24
+
25
+ # ===== Python缓存 =====
26
+ **/__pycache__/
27
+ **/*.py[cod]
28
+ **/*$py.class
29
+ *.so
30
+ .Python
31
+ build/
32
+ develop-eggs/
33
+ dist/
34
+ downloads/
35
+ eggs/
36
+ .eggs/
37
+ lib/
38
+ lib64/
39
+ parts/
40
+ sdist/
41
+ var/
42
+ wheels/
43
+ *.egg-info/
44
+ .installed.cfg
45
+ *.egg
46
+
47
+ # ===== 虚拟环境 =====
48
+ venv/
49
+ env/
50
+ ENV/
51
+ .venv
52
+
53
+ # ===== IDE配置 =====
54
+ .vscode/
55
+ .idea/
56
+ *.swp
57
+ *.swo
58
+ *~
59
+
60
+ # ===== Pytest和覆盖率 =====
61
+ .pytest_cache/
62
+ .coverage
63
+ htmlcov/
64
+
65
+ # ===== 日志和临时文件 =====
66
+ *.log
67
+ *.tmp
68
+ *.tmp.txt
69
+ physics_mcp/mcp_output/mcp_logs/
70
+ physics_mcp/mcp_output/output/
Dockerfile CHANGED
@@ -11,6 +11,9 @@ RUN apt-get update && apt-get install -y \
11
  wget \
12
  && rm -rf /var/lib/apt/lists/*
13
 
 
 
 
14
  # Copy physics_mcp folder
15
  COPY physics_mcp /app/physics_mcp
16
 
 
11
  wget \
12
  && rm -rf /var/lib/apt/lists/*
13
 
14
+ # Copy source directory (original NVIDIA physicsnemo code) - REQUIRED
15
+ COPY physics_mcp/source /app/physics_mcp/source
16
+
17
  # Copy physics_mcp folder
18
  COPY physics_mcp /app/physics_mcp
19
 
physics_mcp/mcp_output/requirements.txt CHANGED
@@ -1,19 +1,26 @@
1
  fastmcp>=0.1.0
2
  pydantic>=2.0.0
3
- torch
4
- numpy
5
- scipy
6
- onnx
7
  tritonclient
8
  matplotlib
9
  pandas
10
- pyyaml
11
- cuml
 
 
 
 
12
 
13
- # Optional Dependencies
14
- # wandb
15
- # mlflow
 
16
  # dgl
17
  # pyg
18
- # vtk
19
- # netCDF4
 
 
 
1
  fastmcp>=0.1.0
2
  pydantic>=2.0.0
3
+ torch>=2.4.0
4
+ numpy>=1.22.4
5
+ scipy>=1.9.0
6
+ onnx>=1.14.0
7
  tritonclient
8
  matplotlib
9
  pandas
10
+ pyyaml>=6.0
11
+ tqdm>=4.60.0
12
+ xarray>=2023.1.0
13
+ zarr>=2.14.2
14
+ s3fs>=2023.5.0
15
+ timm>=1.0.0
16
 
17
+ # Optional Dependencies (can be uncommented as needed)
18
+ # cuml>=24.0.0 (requires RAPIDS conda channel - use scipy fallback instead)
19
+ # wandb>=0.13.7
20
+ # mlflow>=2.1.1
21
  # dgl
22
  # pyg
23
+ # vtk>=9.2.6
24
+ # netCDF4>=1.6.3
25
+ # h5py>=3.7.0
26
+ # nvidia-dali-cuda120>=1.35.0
physics_mcp/source/.dockerignore ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ .git
2
+ .github
3
+ .gitlab
4
+ .coverage*
5
+ .*cache
6
+ examples
7
+ docs
8
+ test
physics_mcp/source/.gitattributes ADDED
File without changes
physics_mcp/source/.gitignore ADDED
@@ -0,0 +1,176 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Byte-compiled / optimized / DLL files
2
+ __pycache__/
3
+ *.py[cod]
4
+ *$py.class
5
+
6
+ # C extensions
7
+ *.so
8
+
9
+ # Distribution / packaging
10
+ .Python
11
+ build/
12
+ develop-eggs/
13
+ dist/
14
+ downloads/
15
+ eggs/
16
+ .eggs/
17
+ lib/
18
+ lib64/
19
+ parts/
20
+ sdist/
21
+ var/
22
+ wheels/
23
+ share/python-wheels/
24
+ *.egg-info/
25
+ .installed.cfg
26
+ *.egg
27
+ MANIFEST
28
+
29
+ # PyInstaller
30
+ # Usually these files are written by a python script from a template
31
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
32
+ *.manifest
33
+ *.spec
34
+
35
+ # Installer logs
36
+ pip-log.txt
37
+ pip-delete-this-directory.txt
38
+
39
+ # Unit test / coverage reports
40
+ htmlcov/
41
+ .tox/
42
+ .nox/
43
+ .coverage
44
+ .coverage.*
45
+ .cache
46
+ nosetests.xml
47
+ coverage.xml
48
+ *.cover
49
+ *.py,cover
50
+ .hypothesis/
51
+ .pytest_cache/
52
+ cover/
53
+
54
+ # Translations
55
+ *.mo
56
+ *.pot
57
+
58
+ # Django stuff:
59
+ *.log
60
+ local_settings.py
61
+ db.sqlite3
62
+ db.sqlite3-journal
63
+
64
+ # Flask stuff:
65
+ instance/
66
+ .webassets-cache
67
+
68
+ # Scrapy stuff:
69
+ .scrapy
70
+
71
+ # Sphinx documentation
72
+ docs/_build/
73
+ docs/examples/
74
+
75
+ # PyBuilder
76
+ .pybuilder/
77
+ target/
78
+
79
+ # Jupyter Notebook
80
+ .ipynb_checkpoints
81
+
82
+ # IPython
83
+ profile_default/
84
+ ipython_config.py
85
+
86
+ # pyenv
87
+ # For a library or package, you might want to ignore these files since the code is
88
+ # intended to run in multiple environments; otherwise, check them in:
89
+ # .python-version
90
+
91
+ # pipenv
92
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
93
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
94
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
95
+ # install all needed dependencies.
96
+ #Pipfile.lock
97
+
98
+ # poetry
99
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
100
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
101
+ # commonly ignored for libraries.
102
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
103
+ #poetry.lock
104
+
105
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow
106
+ __pypackages__/
107
+
108
+ # Celery stuff
109
+ celerybeat-schedule
110
+ celerybeat.pid
111
+
112
+ # SageMath parsed files
113
+ *.sage.py
114
+
115
+ # Environments
116
+ .env
117
+ .venv
118
+ env/
119
+ venv/
120
+ ENV/
121
+ env.bak/
122
+ venv.bak/
123
+
124
+ # Spyder project settings
125
+ .spyderproject
126
+ .spyproject
127
+
128
+ # Rope project settings
129
+ .ropeproject
130
+
131
+ # mkdocs documentation
132
+ /site
133
+
134
+ # mypy
135
+ .mypy_cache/
136
+ .dmypy.json
137
+ dmypy.json
138
+
139
+ # Pyre type checker
140
+ .pyre/
141
+
142
+ # pytype static type analyzer
143
+ .pytype/
144
+
145
+ # Cython debug symbols
146
+ cython_debug/
147
+
148
+ # PyCharm
149
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
150
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
151
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
152
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
153
+ .idea/
154
+
155
+ # VsCode
156
+ .vscode/
157
+ .cursor/
158
+
159
+ # VIM
160
+ *.swp
161
+ *~
162
+
163
+ # Additional stuff
164
+ nsight-systems*
165
+ build/
166
+ mlruns/
167
+ checkpoints/
168
+
169
+ # Hydra
170
+ outputs/
171
+ multirun/
172
+ .hydra/
173
+
174
+ # SLURM
175
+ slurm-*.out
176
+ sbatch_logs/
physics_mcp/source/CHANGELOG.md ADDED
@@ -0,0 +1,556 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <!-- markdownlint-disable MD024 -->
2
+ # Changelog
3
+
4
+ All notable changes to this project will be documented in this file.
5
+
6
+ The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
7
+ and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
8
+
9
+ ## [1.3.0a0] - 2025-XX-YY
10
+
11
+ ### Added
12
+
13
+ - Added mixture_of_experts for weather example in physicsnemo.examples.weather.
14
+ **⚠️Warning:** - It uses experimental DiT model subject to future API changes.
15
+ Added some modifications to DiT architecture in physicsnemo.experimental.models.dit.
16
+ Added learnable option to PositionalEmbedding in physicsnemo.models.diffusion.layers.
17
+ - Added lead-time aware training support to the StormCast example.
18
+ - Add a device aware kNN method to physicsnemo.utils.neighbors. Works with CPU or GPU
19
+ by dispatching to the proper optimized library, and torch.compile compatible.
20
+ - Added additional testing of the DoMINO datapipe.
21
+ - Examples: added a new example for full-waveform inversion using diffusion
22
+ models. Accessible in `examples/geophysics/diffusion_fwi`.
23
+ - Domain Parallelism: Domain Parallelism is now available for kNN, radius_search,
24
+ and torch.nn.functional.pad.
25
+ - Unified recipe for crash modeling, supporting Transolver and MeshGraphNet,
26
+ and three transient schemes.
27
+ - Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model,
28
+ which has a specific `.forward()` signature
29
+ - Added abstract interfaces for constructing active learning workflows, contained
30
+ under the `physicsnemo.active_learning` namespace. A preliminary example of how
31
+ to compose and define an active learning workflow is provided in `examples/active_learning`.
32
+ The `moons` example provides a minimal (pedagogical) composition that is meant to
33
+ illustrate how to define the necessary parts of the workflow.
34
+
35
+ ### Changed
36
+
37
+ - Migrated Stokes MGN example to PyTorch Geometric.
38
+ - Migrated Lennard Jones example to PyTorch Geometric.
39
+ - Migrated physicsnemo.utils.sdf.signed_distance_field to a static return,
40
+ torch-only interface. It also now works on distributed meshes and input fields.
41
+ - Refactored DiTBlock to be more modular
42
+ - Added NATTEN 2D neighborhood attention backend for DiTBlock
43
+ - Migrated blood flow example to PyTorch Geometric.
44
+ - Refactored DoMINO model code and examples for performance optimizations and improved readability.
45
+ - Migrated HydroGraphNet example to PyTorch Geometric.
46
+ - Support for saving and loading nested `physicsnemo.Module`s. It is now
47
+ possible to create nested modules with `m = Module(submodule, ...)`, and save
48
+ and load them with `Module.save` and `Module.from_checkpoint`.
49
+ **⚠️Warning:** - The modules have to be `physicsnemo.Module`s, and not
50
+ `torch.nn.Module`s.
51
+ - Support passing custom tokenizer, detokenizer, and attention `Module`s in
52
+ experimental DiT architecture
53
+ - Improved Transolver training recipe's configuration for checkpointing and normalization.
54
+ - Bumped `multi-storage-client` version to 0.33.0 with rust client.
55
+
56
+ ### Deprecated
57
+
58
+ ### Removed
59
+
60
+ ### Fixed
61
+
62
+ - Set `skip_scale` to Python float in U-Net to ensure compilation works.
63
+ - Ensure stream dependencies are handled correctly in physicsnemo.utils.neighbors
64
+ - Fixed the issue with incorrect handling of files with consecutive runs of
65
+ `combine_stl_solids.py` in the X-MGN recipe.
66
+ - Fixed the `RuntimeError: Worker data receiving interrupted` error in the datacenter example.
67
+
68
+ ### Security
69
+
70
+ ### Dependencies
71
+
72
+ ## [1.2.0] - 2025-08-26
73
+
74
+ ### Added
75
+
76
+ - Diffusion Transformer (DiT) model. The DiT model can be accessed in
77
+ `physicsnemo.experimental.models.dit.DiT`. **⚠️Warning:** - Experimental feature
78
+ subject to future API changes.
79
+ - Improved documentation for diffusion models and diffusion utils.
80
+ - Safe API to override `__init__`'s arguments saved in checkpoint file with
81
+ `Module.from_checkpoint("chkpt.mdlus", override_args=set(...))`.
82
+ - PyTorch Geometric MeshGraphNet backend.
83
+ - Functionality in DoMINO to take arbitrary number of `scalar` or `vector`
84
+ global parameters and encode them using `class ParameterModel`
85
+ - TopoDiff model and example.
86
+ - Added ability for DoMINO model to return volume neighbors.
87
+ - Added functionality in DoMINO recipe to introduce physics residual losses.
88
+ - Diffusion models, metrics, and utils: implementation of Student-t
89
+ distribution for EDM-based diffusion models (t-EDM). This feature is adapted
90
+ from the paper [Heavy-Tailed Diffusion Models, Pandey et al.](https://arxiv.org/abs/2410.14171>).
91
+ This includes a new EDM preconditioner (`tEDMPrecondSuperRes`), a loss
92
+ function (`tEDMResidualLoss`), and a new option in corrdiff `diffusion_step`.
93
+ &#9888;&#65039; This is an experimental feature that can be accessed through the
94
+ `physicsnemo.experimental` module; it might also be subjected to API changes
95
+ without notice.
96
+ - Bumped Ruff version from 0.0.290 to 0.12.5. Replaced Black with `ruff-format`.
97
+ - Domino improvements with Unet attention module and user configs
98
+ - Hybrid MeshGraphNet for modeling structural deformation
99
+ - Enabled TransformerEngine backend in the `transolver` model.
100
+ - Inference code for x-meshgraphnet example for external aerodynamics.
101
+ - Added a new example for external_aerodynamics: training `transolver` on
102
+ irregular mesh data for DrivaerML surface data.
103
+ - Added a new example for external aerodynamics for finetuning pretrained models.
104
+
105
+ ### Changed
106
+
107
+ - Diffusion utils: `physicsnemo.utils.generative` renamed into `physicsnemo.utils.diffusion`
108
+ - Diffusion models: in CorrDiff model wrappers (`EDMPrecondSuperResolution` and
109
+ `UNet`), the arguments `profile_mode` and `amp_mode` cannot be overriden by
110
+ `from_checkpoint`. They are now properties that can be dynamically changed
111
+ *after* the model instantiation with, for example, `model.amp_mode = True`
112
+ and `model.profile_mode = False`.
113
+ - Updated healpix data module to use correct `DistributedSampler` target for
114
+ test data loader
115
+ - Existing DGL-based vortex shedding example has been renamed to `vortex_shedding_mgn_dgl`.
116
+ Added new `vortex_shedding_mgn` example that uses PyTorch Geometric instead.
117
+ - HEALPixLayer can now use earth2grid HEALPix padding ops, if desired
118
+ - Migrated Vortex Shedding Reduced Mesh example to PyTorch Geometric.
119
+ - CorrDiff example: fixed bugs when training regression `UNet`.
120
+ - Diffusion models: fixed bugs related to gradient checkpointing on non-square
121
+ images.
122
+ - Diffusion models: created a separate class `Attention` for clarity and
123
+ modularity. Updated `UNetBlock` accordingly to use the `Attention` class
124
+ instead of custom attention logic. This will update the model architecture
125
+ for `SongUNet`-based diffusion models. Changes are not BC-breaking and are
126
+ transparent to the user.
127
+ - &#9888;&#65039; **BC-breaking:** refactored the automatic mixed precision
128
+ (AMP) API in layers and models defined in `physicsnemo/models/diffusion/` for
129
+ improved usability. Note: it is now, not only possible, but *required* to
130
+ explicitly set `model.amp_mode = True` in order to use the model in a
131
+ `torch.autocast` clause. This applies to all `SongUNet`-based models.
132
+ - Diffusion models: fixed and improved API to enable fp16 forward pass in
133
+ `UNet` and `EDMPrecondSuperResolution` model wrappers; fp16 forward pass can
134
+ now be toggled/untoggled by setting `model.use_fp16 = True`.
135
+ - Diffusion models: improved API for Apex group norm. `SongUNet`-based models
136
+ will automatically perform conversion of the input tensors to
137
+ `torch.channels_last` memory format when `model.use_apex_gn` is `True`. New
138
+ warnings are raised when attempting to use Apex group norm on CPU.
139
+ - Diffusion utils: systematic compilation of patching operations in `stochastic_sampler`
140
+ for improved performance.
141
+ - CorrDiff example: added option for Student-t EDM (t-EDM) in `train.py` and
142
+ `generate.py`. When training a CorrDiff diffusion model, this feature can be
143
+ enabled with the hydra overrides `++training.hp.distribution=student_t` and
144
+ `++training.hp.nu_student_t=<nu_value>`. For generation, this feature can be
145
+ enabled with similar overrides: `++generation.distribution=student_t` and
146
+ `++generation.nu_student_t=<nu_value>`.
147
+ - CorrDiff example: the parameters `P_mean` and `P_std` (used to compute the
148
+ noise level `sigma`) are now configurable. They can be set with the hydra
149
+ overrides `++training.hp.P_mean=<P_mean_value>` and
150
+ `++training.hp.P_std=<P_std_value>` for training (and similar ones with
151
+ `training.hp` replaced by `generation` for generation).
152
+ - Diffusion utils: patch-based inference and lead time support with
153
+ deterministic sampler.
154
+ - Existing DGL-based XAeroNet example has been renamed to `xaeronet_dgl`.
155
+ Added new `xaeronet` example that uses PyTorch Geometric instead.
156
+ - Updated the deforming plate example to use the Hybrid MeshGraphNet model.
157
+ - &#9888;&#65039; **BC-breaking:** Refactored the `transolver` model to improve
158
+ readability and performance, and extend to more use cases.
159
+ - Diffusion models: improved lead time support for `SongUNetPosLtEmbd` and
160
+ `EDMLoss`. Lead-time embeddings can now be used with/without positional
161
+ embeddings.
162
+ - Diffusion models: consolidate `ApexGroupNorm` and `GroupNorm` in
163
+ `models/diffusion/layers.py` with a factory `get_group_norm` that can
164
+ be used to instantiate either one of them. `get_group_norm` is now the
165
+ recommended way to instantiate a GroupNorm layer in `SongUNet`-based and
166
+ other diffusion models.
167
+ - Physicsnemo models: improved checkpoint loading API in
168
+ `Module.from_checkpoint` that now exposes a `strict` parameter to raise error
169
+ on missing/unexpected keys, similar to that used in
170
+ `torch.nn.Module.load_state_dict`.
171
+ - Migrated Hybrid MGN and deforming plate example to PyTorch Geometric.
172
+
173
+ ### Fixed
174
+
175
+ - Bug fixes in DoMINO model in sphere sampling and tensor reshaping
176
+ - Bug fixes in DoMINO utils random sampling and test.py
177
+ - Optimized DoMINO config params based on DrivAer ML
178
+
179
+ ## [1.1.1] - 2025-06-16
180
+
181
+ ### Fixed
182
+
183
+ - Fixed an inadvertent change to the deterministic sampler 2nd order correction
184
+ - Bug Fix in Domino model ball query layer
185
+ - Fixed bug models/unet/unet.py: setting num_conv_layers=1 gives errors
186
+
187
+ ## [1.1.0] - 2025-06-05
188
+
189
+ ### Added
190
+
191
+ - Added ReGen score-based data assimilation example
192
+ - General purpose patching API for patch-based diffusion
193
+ - New positional embedding selection strategy for CorrDiff SongUNet models
194
+ - Added Multi-Storage Client to allow checkpointing to/from Object Storage
195
+ - Added a new aerodynamics example using DoMINO to compute design sensitivities
196
+ (e.g., drag adjoint) with respect to underlying input geometry.
197
+
198
+ ### Changed
199
+
200
+ - Simplified CorrDiff config files, updated default values
201
+ - Refactored CorrDiff losses and samplers to use the patching API
202
+ - Support for non-square images and patches in patch-based diffusion
203
+ - ERA5 download example updated to use current file format convention and
204
+ restricts global statistics computation to the training set
205
+ - Support for training custom StormCast models and various other improvements for StormCast
206
+ - Updated CorrDiff training code to support multiple patch iterations to amortize
207
+ regression cost and usage of `torch.compile`
208
+ - Refactored `physicsnemo/models/diffusion/layers.py` to optimize data type
209
+ casting workflow, avoiding unnecessary casting under autocast mode
210
+ - Refactored Conv2d to enable fusion of conv2d with bias addition
211
+ - Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of
212
+ Apex GroupNorm, fusion of activation with GroupNorm, and AMP workflow.
213
+ - Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
214
+ - Updated `from_checkpoint` to accommodate conversion between Apex optimized ckp
215
+ and non-optimized ckp
216
+ - Refactored CorrDiff NVTX annotation workflow to be configurable
217
+ - Refactored `ResidualLoss` to support patch-accumulating training for
218
+ amortizing regression costs
219
+ - Explicit handling of Warp device for ball query and sdf
220
+ - Merged SongUNetPosLtEmb with SongUNetPosEmb, add support for batch>1
221
+ - Add lead time embedding support for `positional_embedding_selector`. Enable
222
+ arbitrary positioning of probabilistic variables
223
+ - Enable lead time aware regression without CE loss
224
+ - Bumped minimum PyTorch version from 2.0.0 to 2.4.0, to minimize
225
+ support surface for `physicsnemo.distributed` functionality.
226
+
227
+ ### Dependencies
228
+
229
+ - Made `nvidia.dali` an optional dependency
230
+
231
+ ## [1.0.1] - 2025-03-25
232
+
233
+ ### Added
234
+
235
+ - Added version checks to ensure compatibility with older PyTorch for distributed
236
+ utilities and ShardTensor
237
+
238
+ ### Fixed
239
+
240
+ - `EntryPoint` error that occurred during physicsnemo checkpoint loading
241
+
242
+ ## [1.0.0] - 2025-03-18
243
+
244
+ ### Added
245
+
246
+ - DoMINO model architecture, datapipe and training recipe
247
+ - Added matrix decomposition scheme to improve graph partitioning
248
+ - DrivAerML dataset support in FIGConvNet example.
249
+ - Retraining recipe for DoMINO from a pretrained model checkpoint
250
+ - Prototype support for domain parallelism of using ShardTensor (new).
251
+ - Enable DeviceMesh initialization via DistributedManager.
252
+ - Added Datacenter CFD use case.
253
+ - Add leave-in profiling utilities to physicsnemo, to easily enable torch/python/nsight
254
+ profiling in all aspects of the codebase.
255
+
256
+ ### Changed
257
+
258
+ - Refactored StormCast training example
259
+ - Enhancements and bug fixes to DoMINO model and training example
260
+ - Enhancement to parameterize DoMINO model with inlet velocity
261
+ - Moved non-dimensionalization out of domino datapipe to datapipe in domino example
262
+ - Updated utils in `physicsnemo.launch.logging` to avoid unnecessary `wandb` and `mlflow`
263
+ imports
264
+ - Moved to experiment-based Hydra config in Lagrangian-MGN example
265
+ - Make data caching optional in `MeshDatapipe`
266
+ - The use of older `importlib_metadata` library is removed
267
+
268
+ ### Deprecated
269
+
270
+ - ProcessGroupConfig is tagged for future deprecation in favor of DeviceMesh.
271
+
272
+ ### Fixed
273
+
274
+ - Update pytests to skip when the required dependencies are not present
275
+ - Bug in data processing script in domino training example
276
+ - Fixed NCCL_ASYNC_ERROR_HANDLING deprecation warning
277
+
278
+ ### Dependencies
279
+
280
+ - Remove the numpy dependency upper bound
281
+ - Moved pytz and nvtx to optional
282
+ - Update the base image for the Dockerfile
283
+ - Introduce Multi-Storage Client (MSC) as an optional dependency.
284
+ - Introduce `wrapt` as an optional dependency, needed when using
285
+ ShardTensor's automatic domain parallelism
286
+
287
+ ## [0.9.0] - 2024-12-04
288
+
289
+ ### Added
290
+
291
+ - Graph Transformer processor for GraphCast/GenCast.
292
+ - Utility to generate STL from Signed Distance Field.
293
+ - Metrics for CAE and CFD domain such as integrals, drag, and turbulence invariances and
294
+ spectrum.
295
+ - Added gradient clipping to StaticCapture utilities.
296
+ - Bistride Multiscale MeshGraphNet example.
297
+ - FIGConvUNet model and example.
298
+ - The Transolver model.
299
+ - The XAeroNet model.
300
+ - Incorporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
301
+ cross entropy loss.
302
+ - Option to offload checkpoints to further reduce memory usage
303
+ - Added StormCast model training and simple inference to examples
304
+ - Multi-scale geometry features for DoMINO model.
305
+
306
+ ### Changed
307
+
308
+ - Refactored CorrDiff training recipe for improved usability
309
+ - Fixed timezone calculation in datapipe cosine zenith utility.
310
+ - Refactored EDMPrecondSRV2 preconditioner and fixed the bug related to the metadata
311
+ - Extended the checkpointing utility to store metadata.
312
+ - Corrected missing export of logging function used by transolver model
313
+
314
+ ## [0.8.0] - 2024-09-24
315
+
316
+ ### Added
317
+
318
+ - Graph Transformer processor for GraphCast/GenCast.
319
+ - Utility to generate STL from Signed Distance Field.
320
+ - Metrics for CAE and CFD domain such as integrals, drag, and turbulence invariances and
321
+ spectrum.
322
+ - Added gradient clipping to StaticCapture utilities.
323
+ - Bistride Multiscale MeshGraphNet example.
324
+
325
+ ### Changed
326
+
327
+ - Refactored CorrDiff training recipe for improved usability
328
+ - Fixed timezone calculation in datapipe cosine zenith utility.
329
+
330
+ ## [0.7.0] - 2024-07-23
331
+
332
+ ### Added
333
+
334
+ - Code logging for CorrDiff via Wandb.
335
+ - Augmentation pipeline for CorrDiff.
336
+ - Regression output as additional conditioning for CorrDiff.
337
+ - Learnable positional embedding for CorrDiff.
338
+ - Support for patch-based CorrDiff training and generation (stochastic sampling only)
339
+ - Enable CorrDiff multi-gpu generation
340
+ - Diffusion model for fluid data super-resolution (CMU contribution).
341
+ - The Virtual Foundry GraphNet.
342
+ - A synthetic dataloader for global weather prediction models, demonstrated on GraphCast.
343
+ - Sorted Empirical CDF CRPS algorithm
344
+ - Support for history, cos zenith, and downscaling/upscaling in the ERA5 HDF5 dataloader.
345
+ - An example showing how to train a "tensor-parallel" version of GraphCast on a
346
+ Shallow-Water-Equation example.
347
+ - 3D UNet
348
+ - AeroGraphNet example of training of MeshGraphNet on Ahmed body and DrivAerNet datasets.
349
+ - Warp SDF routine
350
+ - DLWP HEALPix model
351
+ - Pangu Weather model
352
+ - Fengwu model
353
+ - SwinRNN model
354
+ - Modulated AFNO model
355
+
356
+ ### Changed
357
+
358
+ - Raise `PhysicsNeMoUndefinedGroupError` when querying undefined process groups
359
+ - Changed Indexing error in `examples/cfd/swe_nonlinear_pino` for `physicsnemo` loss function
360
+ - Safeguarding against uninitialized usage of `DistributedManager`
361
+
362
+ ### Removed
363
+
364
+ - Remove mlflow from deployment image
365
+
366
+ ### Fixed
367
+
368
+ - Fixed bug in the partitioning logic for distributing graph structures
369
+ intended for distributed message-passing.
370
+ - Fixed bugs for corrdiff diffusion training of `EDMv1` and `EDMv2`
371
+ - Fixed bug when trying to save DDP model trained through unified recipe
372
+
373
+ ### Dependencies
374
+
375
+ - Update DALI to CUDA 12 compatible version.
376
+ - Update minimum python version to 3.10
377
+
378
+ ## [0.6.0] - 2024-04-17
379
+
380
+ ### Added
381
+
382
+ - The citation file.
383
+ - Link to the CWA dataset.
384
+ - ClimateDatapipe: an improved datapipe for HDF5/NetCDF4 formatted climate data
385
+ - Performance optimizations to CorrDiff.
386
+ - Physics-Informed Nonlinear Shallow Water Equations example.
387
+ - Warp neighbor search routine with a minimal example.
388
+ - Strict option for loading PhysicsNeMo checkpoints.
389
+ - Regression only or diffusion only inference for CorrDiff.
390
+ - Support for organization level model files on NGC file system
391
+ - Physics-Informed Magnetohydrodynamics example.
392
+
393
+ ### Changed
394
+
395
+ - Updated Ahmed Body and Vortex Shedding examples to use Hydra config.
396
+ - Added more config options to FCN AFNO example.
397
+ - Moved positional embedding in CorrDiff from the dataloader to network architecture
398
+
399
+ ### Deprecated
400
+
401
+ - `physicsnemo.models.diffusion.preconditioning.EDMPrecondSR`. Use `EDMPrecondSRV2` instead.
402
+
403
+ ### Removed
404
+
405
+ - Pickle dependency for CorrDiff.
406
+
407
+ ### Fixed
408
+
409
+ - Consistent handling of single GPU runs in DistributedManager
410
+ - Output location of objects downloaded with NGC file system
411
+ - Bug in scaling the conditional input in CorrDiff deterministic sampler
412
+
413
+ ### Dependencies
414
+
415
+ - Updated DGL build in Dockerfile
416
+ - Updated default base image
417
+ - Moved Onnx from optional to required dependencies
418
+ - Optional Makani dependency required for SFNO model.
419
+
420
+ ## [0.5.0] - 2024-01-25
421
+
422
+ ### Added
423
+
424
+ - Distributed process group configuration mechanism.
425
+ - DistributedManager utility to instantiate process groups based on a process group config.
426
+ - Helper functions to facilitate distributed training with shared parameters.
427
+ - Brain anomaly detection example.
428
+ - Updated Frechet Inception Distance to use Wasserstein 2-norm with improved stability.
429
+ - Molecular Dynamics example.
430
+ - Improved usage of GraphPartition, added more flexible ways of defining a partitioned graph.
431
+ - Physics-Informed Stokes Flow example.
432
+ - Profiling markers, benchmarking and performance optimizations for CorrDiff inference.
433
+ - Unified weather model training example.
434
+
435
+ ### Changed
436
+
437
+ - MLFLow logging such that only proc 0 logs to MLFlow.
438
+ - FNO given separate methods for constructing lift and spectral encoder layers.
439
+
440
+ ### Removed
441
+
442
+ - The experimental SFNO
443
+
444
+ ### Dependencies
445
+
446
+ - Removed experimental SFNO dependencies
447
+ - Added CorrDiff dependencies (cftime, einops, pyspng, nvtx)
448
+ - Made tqdm a required dependency
449
+
450
+ ## [0.4.0] - 2023-11-20
451
+
452
+ ### Added
453
+
454
+ - Added Stokes flow dataset
455
+ - An experimental version of SFNO to be used in unified training recipe for
456
+ weather models
457
+ - Added distributed FFT utility.
458
+ - Added ruff as a linting tool.
459
+ - Ported utilities from PhysicsNeMo Launch to main package.
460
+ - EDM diffusion models and recipes for training and sampling.
461
+ - NGC model registry download integration into package/filesystem.
462
+ - Denoising diffusion tutorial.
463
+
464
+ ### Changed
465
+
466
+ - The AFNO input argument `img_size` to `inp_shape`
467
+ - Integrated the network architecture layers from PhysicsNeMo-Sym.
468
+ - Updated the SFNO model, and the training and inference recipes.
469
+
470
+ ### Fixed
471
+
472
+ - Fixed physicsnemo.Module `from_checkpoint` to work from custom model classes
473
+
474
+ ### Dependencies
475
+
476
+ - Updated the base container to PyTorch 23.10.
477
+ - Updated examples to use Pydantic v2.
478
+
479
+ ## [0.3.0] - 2023-09-21
480
+
481
+ ### Added
482
+
483
+ - Added ability to compute CRPS(..., dim: int = 0).
484
+ - Added EFI for arbitrary climatological CDF.
485
+ - Added Kernel CRPS implementation (kcrps)
486
+ - Added distributed utilities to create process groups and orthogonal process groups.
487
+ - Added distributed AFNO model implementation.
488
+ - Added distributed utilities for communication of buffers of varying size per rank.
489
+ - Added distributed utilities for message passing across multiple GPUs.
490
+ - Added instructions for docker build on ARM architecture.
491
+ - Added batching support and fix the input time step for the DLWP wrapper.
492
+
493
+ ### Changed
494
+
495
+ - Updating file system cache location to physicsnemo folder
496
+
497
+ ### Fixed
498
+
499
+ - Fixed physicsnemo uninstall in CI docker image
500
+
501
+ ### Security
502
+
503
+ - Handle the tar ball extracts in a safer way.
504
+
505
+ ### Dependencies
506
+
507
+ - Updated the base container to latest PyTorch 23.07.
508
+ - Update DGL version.
509
+ - Updated require installs for python wheel
510
+ - Added optional dependency list for python wheel
511
+
512
+ ## [0.2.1] - 2023-08-08
513
+
514
+ ### Fixed
515
+
516
+ - Added a workaround fix for the CUDA graphs error in multi-node runs
517
+
518
+ ### Security
519
+
520
+ - Update `certifi` package version
521
+
522
+ ## [0.2.0] - 2023-08-07
523
+
524
+ ### Added
525
+
526
+ - Added a CHANGELOG.md
527
+ - Added build support for internal DGL
528
+ - 4D Fourier Neural Operator model
529
+ - Ahmed body dataset
530
+ - Unified Climate Datapipe
531
+
532
+ ### Changed
533
+
534
+ - DGL install changed from pypi to source
535
+ - Updated SFNO to add support for super resolution, flexible checkpointing, etc.
536
+
537
+ ### Fixed
538
+
539
+ - Fixed issue with torch-harmonics version locking
540
+ - Fixed the PhysicsNeMo editable install
541
+ - Fixed AMP bug in static capture
542
+
543
+ ### Security
544
+
545
+ - Fixed security issues with subprocess and urllib in `filesystem.py`
546
+
547
+ ### Dependencies
548
+
549
+ - Updated the base container to latest PyTorch base container which is based on torch 2.0
550
+ - Container now supports CUDA 12, Python 3.10
551
+
552
+ ## [0.1.0] - 2023-05-08
553
+
554
+ ### Added
555
+
556
+ - Initial public release.
physics_mcp/source/CITATION.cff ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ cff-version: 1.2.0
2
+ message: "If you use this software, please cite it as below."
3
+ title: "NVIDIA PhysicsNeMo: An open-source framework for physics-based deep learning in science and engineering"
4
+ date-released: "2023-02-24"
5
+ authors:
6
+ - name: "PhysicsNeMo Contributors"
7
+ repository-code: "https://github.com/NVIDIA/physicsnemo"
physics_mcp/source/CONTRIBUTING.md ADDED
@@ -0,0 +1,251 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # PhysicsNeMo Contribution Guide
2
+
3
+ ## Introduction
4
+
5
+ Welcome to Project PhysicsNeMo! We're excited you're here and want to contribute.
6
+ This documentation is intended for individuals and institutions interested in
7
+ contributing to PhysicsNeMo. PhysicsNeMo is an open-source project and, as such, its
8
+ success relies on its community of contributors willing to keep improving it.
9
+ Your contribution will be a valued addition to the code base; we simply ask
10
+ that you read this page and understand our contribution process, whether you
11
+ are a seasoned open-source contributor or whether you are a first-time
12
+ contributor.
13
+
14
+ ### Communicate with Us
15
+
16
+ We are happy to talk with you about your needs for PhysicsNeMo and your ideas for
17
+ contributing to the project. One way to do this is to create an issue discussing
18
+ your thoughts. It might be that a very similar feature is under development or
19
+ already exists, so an issue is a great starting point. If you are looking for an
20
+ issue to resolve that will help, refer to the
21
+ [issue](https://github.com/NVIDIA/physicsnemo/issues) section.
22
+ If you are considering collaborating with NVIDIA PhysicsNeMo team to enhance PhysicsNeMo,
23
+ fill this [proposal form](https://forms.gle/fYsbZEtgRWJUQ3oQ9) and
24
+ we will get back to you.
25
+
26
+ ## Contribute to PhysicsNeMo-Core
27
+
28
+ ### Pull Requests
29
+
30
+ Developer workflow for code contributions is as follows:
31
+
32
+ 1. Developers must first [fork](https://help.github.com/en/articles/fork-a-repo)
33
+ the [upstream](https://github.com/NVIDIA/physicsnemo) PhysicsNeMo repository.
34
+
35
+ 2. Git clone the forked repository and push changes to the personal fork.
36
+
37
+ 3. Once the code changes are staged on the fork and ready for review, a
38
+ [Pull Request](https://help.github.com/en/articles/about-pull-requests) (PR)
39
+ can be [requested](https://help.github.com/en/articles/creating-a-pull-request)
40
+ to merge the changes from a branch of the fork into a selected branch of upstream.
41
+
42
+ - Exercise caution when selecting the source and target branches for the PR.
43
+ - Ensure that you update the [`CHANGELOG.md`](CHANGELOG.md) to reflect your contributions.
44
+ - Creation of a PR kicks off CI and a code review process.
45
+ - At least one PhysicsNeMo engineer will be assigned for the review.
46
+
47
+ 4. The PR will be accepted and the corresponding issue closed after adequate review and
48
+ testing has been completed. Note that every PR should correspond to an open issue and
49
+ should be linked on Github.
50
+
51
+ ### Licensing Information
52
+
53
+ All source code files should start with this paragraph:
54
+
55
+ ```bash
56
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
57
+ # SPDX-FileCopyrightText: All rights reserved.
58
+ # SPDX-License-Identifier: Apache-2.0
59
+ #
60
+ # Licensed under the Apache License, Version 2.0 (the "License");
61
+ # you may not use this file except in compliance with the License.
62
+ # You may obtain a copy of the License at
63
+ #
64
+ # http://www.apache.org/licenses/LICENSE-2.0
65
+ #
66
+ # Unless required by applicable law or agreed to in writing, software
67
+ # distributed under the License is distributed on an "AS IS" BASIS,
68
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
69
+ # See the License for the specific language governing permissions and
70
+ # limitations under the License.
71
+ ```
72
+
73
+ ### Signing Your Work
74
+
75
+ - We require that all contributors "sign-off" on their commits. This certifies that the
76
+ contribution is your original work, or you have rights to submit it under the same
77
+ license, or a compatible license.
78
+
79
+ - Any contribution which contains commits that are not Signed-Off will not be accepted.
80
+
81
+ - To sign off on a commit you simply use the `--signoff` (or `-s`) option when
82
+ committing your changes:
83
+
84
+ ```bash
85
+ git commit -s -m "Add cool feature."
86
+ ```
87
+
88
+ This will append the following to your commit message:
89
+
90
+ ```text
91
+ Signed-off-by: Your Name <your@email.com>
92
+ ```
93
+
94
+ - Full text of the DCO:
95
+
96
+ ```text
97
+ Developer Certificate of Origin
98
+ Version 1.1
99
+
100
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
101
+ 1 Letterman Drive
102
+ Suite D4700
103
+ San Francisco, CA, 94129
104
+
105
+ Everyone is permitted to copy and distribute verbatim copies of this license
106
+ document, but changing it is not allowed.
107
+ ```
108
+
109
+ ```text
110
+ Developer's Certificate of Origin 1.1
111
+
112
+ By making a contribution to this project, I certify that:
113
+
114
+ (a) The contribution was created in whole or in part by me and I have the right to
115
+ submit it under the open source license indicated in the file; or
116
+
117
+ (b) The contribution is based upon previous work that, to the best of my knowledge,
118
+ is covered under an appropriate open source license and I have the right under that
119
+ license to submit that work with modifications, whether created in whole or in part
120
+ by me, under the same open source license (unless I am permitted to submit under a
121
+ different license), as indicated in the file; or
122
+
123
+ (c) The contribution was provided directly to me by some other person who certified
124
+ (a), (b) or (c) and I have not modified it.
125
+
126
+ (d) I understand and agree that this project and the contribution are public and
127
+ that a record of the contribution (including all personal information I submit with
128
+ it, including my sign-off) is maintained indefinitely and may be redistributed
129
+ consistent with this project or the open source license(s) involved.
130
+
131
+ ```
132
+
133
+ ### Pre-commit
134
+
135
+ For PhysicsNeMo development, [pre-commit](https://pre-commit.com/) is **required**.
136
+ This will not only help developers pass the CI pipeline, but also accelerate reviews.
137
+ Contributions that have not used pre-commit will *not be reviewed*.
138
+
139
+ `pre-commit` is installed as part of the `dev` optional dependencies defined in `pyproject.toml`.
140
+ To install `pre-commit` in an existing environment, follow the below steps inside the PhysicsNeMo
141
+ repository folder:
142
+
143
+ ```bash
144
+ pip install pre-commit
145
+ pre-commit install
146
+ ```
147
+
148
+ Once the above commands are executed, the pre-commit hooks will be activated and all
149
+ the commits will be checked for appropriate formatting.
150
+
151
+ ### Continuous Integration (CI)
152
+
153
+ To ensure quality of the code, your merge request (MR) will pass through several CI checks.
154
+ It is mandatory for your MRs to pass these pipelines to ensure a successful merge.
155
+ Please keep checking this document for the latest guidelines on pushing code. Currently,
156
+ The pipeline has following stages:
157
+
158
+ 1. `format`
159
+ *Pre-commit will check this for you!* Checks for formatting of your
160
+ Python code, using `ruff format` via [Ruff](https://docs.astral.sh/ruff/).
161
+ If your MR fails this test, run `ruff format <script-name>.py` on
162
+ problematic scripts and Ruff will take care of the rest.
163
+
164
+ 2. `interrogate`
165
+ *Pre-commit will check this for you!*
166
+ Checks if the code being pushed is well documented. The goal is to make the
167
+ documentation live inside code. Very few exceptions are made.
168
+ Elements that are fine to have no documentation include `init-module`, `init-method`,
169
+ `private` and `semiprivate` classes/functions and `dunder` methods. For definitions of
170
+ these, refer [interrogate](https://interrogate.readthedocs.io/en/latest/). Meaning for
171
+ some methods/functions is very explicit and exceptions for these are made. These
172
+ include `forward`, `reset_parameters`, `extra_repr`, `MetaData`. If your MR fails this
173
+ test, add the missing documentation. Take a look at the pipeline output for hints on
174
+ which functions/classes need documentation.
175
+ To test the documentation before making a commit, you can run the following during
176
+ your development
177
+
178
+ ```bash
179
+ interrogate \
180
+ --ignore-init-method \
181
+ --ignore-init-module \
182
+ --ignore-module \
183
+ --ignore-private \
184
+ --ignore-semiprivate \
185
+ --ignore-magic \
186
+ --fail-under 99 \
187
+ --exclude '[setup.py]' \
188
+ --ignore-regex forward \
189
+ --ignore-regex reset_parameters \
190
+ --ignore-regex extra_repr \
191
+ --ignore-regex MetaData \
192
+ -vv \
193
+ --color \
194
+ ./physicsnemo/
195
+ ```
196
+
197
+ 3. `lint`
198
+ *Pre-commit will check this for you!*
199
+ Linters will perform static analysis to check the style, complexity, errors
200
+ and more. For markdown files `markdownlint` is used, it's suggested to use
201
+ the vscode, neovim or sublime
202
+ [extensions](https://github.com/DavidAnson/markdownlint#related).
203
+ PhysicsNeMo uses `ruff check` via [Ruff](https://docs.astral.sh/ruff/) for
204
+ linting of various types. Currently we use flake8/pycodestyle (`E`),
205
+ Pyflakes (`F`), flake8-bandit (`S`), isort (`I`), and performance (`PERF`)
206
+ rules. Many rule violations will be automatically fixed by Ruff; others may
207
+ require manual changes.
208
+
209
+ 4. `license`
210
+ *Pre-commit will check this for you!*
211
+ Checks for correct license headers of all files.
212
+ To run this locally use `make license`.
213
+ See the Licensing Information section above for details about the license header required.
214
+
215
+ 5. `pytest`
216
+ Checks if the test scripts from the `test` folder run and produce desired outputs. It
217
+ is imperative that your changes don't break the existing tests. If your MR fails this
218
+ test, you will have to review your changes and fix the issues.
219
+ To run pytest locally you can simply run `pytest` inside the `test` folder.
220
+
221
+ While writing these tests, we encourage you to make use of the [`@import_or_fail`](https://github.com/NVIDIA/physicsnemo/blob/main/test/pytest_utils.py#L25)
222
+ decorator to appropriately skip your tests for developers and users not having your
223
+ test specific dependencies. This mechanism helps us provide a better developer and
224
+ user experience when working with the unit tests.
225
+
226
+ Some of the tests require test data to be run; otherwise, they will be skipped.
227
+ To get the data (available to NVIDIANs only), set the `TEST_DATA_DIR` environment variable
228
+ to a desired value and run `make get-data`. After that, pytest will use the same
229
+ variable to find the test data. Alternatively, you can pass it explicitly using
230
+ `pytest --nfs-data-dir=<path to test data>`.
231
+
232
+ 6. `doctest`
233
+ Checks if the examples in the docstrings run and produce desired outputs.
234
+ It is highly recommended that you provide simple examples of your functions/classes
235
+ in the code's docstring itself.
236
+ Keep these examples simple and also add the expected outputs.
237
+ Refer [doctest](https://docs.python.org/3/library/doctest.html) for more information.
238
+ If your MR fails this test, check your changes and the docstrings.
239
+ To run doctest locally, you can simply run `pytest --doctest-modules` inside the
240
+ `physicsnemo` folder.
241
+
242
+ 7. `coverage`
243
+ Checks if your code additions have sufficient coverage.
244
+ Refer [coverage](https://coverage.readthedocs.io/en/6.5.0/index.html#) for more details.
245
+ If your MR fails this test, this means that you have not added enough tests to the `test`
246
+ folder for your module/functions.
247
+ Add extensive test scripts to cover different
248
+ branches and lines of your additions.
249
+ Aim for more than 80% code coverage.
250
+ To test coverage locally, run the `get_coverage.sh` script from the `test` folder and
251
+ check the coverage of the module that you added/edited.
physics_mcp/source/FAQ.md ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Frequently Asked Questions about PhysicsNeMo
2
+
3
+ ## Table of contents
4
+
5
+ - [What is the recommended hardware for training using PhysicsNeMo framework?](#what-is-the-recommended-hardware-for-training-using-physicsnemo-framework)
6
+ - [What model architectures are in PhysicsNeMo?](#what-model-architectures-are-in-physicsnemo)
7
+ - [What is the difference between PhysicsNeMo Core and Symbolic?](#what-is-the-difference-between-physicsnemo-core-and-symbolic)
8
+ - [What can I do if I dont see a PDE in PhysicsNeMo?](#what-can-i-do-if-i-dont-see-a-pde-in-physicsnemo)
9
+ - [What is the difference between the pip install and the container?](#what-is-the-difference-between-the-pip-install-and-the-container)
10
+
11
+ ## What is the recommended hardware for training using PhysicsNeMo framework?
12
+
13
+ Please refer to the recommended hardware section:
14
+ [System Requirements](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#system-requirements)
15
+
16
+ ## What model architectures are in PhysicsNeMo?
17
+
18
+ NVIDIA PhysicsNeMo is built on top of PyTorch and you can build and train any model
19
+ architecture you want in PhysicsNeMo. PhysicsNeMo however has a catalog of models that
20
+ have been packaged in a configurable form to make it easy to retrain with new data or certain
21
+ config parameters. Examples include GNNs like MeshGraphNet or Neural Operators like FNO.
22
+ PhysicsNeMo samples have more models that illustrate how a specific approach with a specific
23
+ model architecture can be applied to a specific problem.
24
+ These are reference starting points for users to get started.
25
+
26
+ You can find the list of built in model architectures
27
+ [here](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models) and
28
+ [here](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/models)
29
+
30
+ ## What is the difference between PhysicsNeMo Core and Symbolic?
31
+
32
+ PhysicsNeMo core is the foundational module that provides the core algorithms, network
33
+ architectures and utilities that cover a broad spectrum of Physics-ML approaches.
34
+ PhysicsNeMo Symbolic provides pythonic APIs, algorithms and utilities to be used with
35
+ PhysicsNeMo core, to explicitly physics inform the model training. This includes symbolic
36
+ APIs for PDEs, domain sampling and PDE-based residuals. It also provides higher level
37
+ abstraction to compose a training loop from specification of the geometry, PDEs and
38
+ constraints like boundary conditions using simple symbolic APIs.
39
+ So if you are familiar with PyTorch and want to train model from a dataset, you start
40
+ with PhysicsNeMo core and you import PhysicsNeMo symbolic to bring in explicit domain knowledge.
41
+ Please refer to the [DeepONet example](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/darcy_deeponet_physics)
42
+ that illustrates the concept.
43
+ If you are an engineer or domain expert accustomed to using numerical solvers, you can
44
+ use PhysicsNeMo Symbolic to define your problem at a higher level of abstraction. Please
45
+ refer to the [Lid Driven cavity](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/basics/lid_driven_cavity_flow.html)
46
+ that illustrates the concept.
47
+
48
+ ## What can I do if I dont see a PDE in PhysicsNeMo?
49
+
50
+ PhysicsNeMo Symbolic provides a well documented
51
+ [example](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/foundational/1d_wave_equation.html#writing-custom-pdes-and-boundary-initial-conditions)
52
+ that walks you through how to define a custom PDE. Please see the source [here](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/eq/pdes)
53
+ to see the built-in PDE implementation as an additional reference for your own implementation.
54
+
55
+ ## What is the difference between the pip install and the container?
56
+
57
+ There is no functional difference between the two. This is to simplify the ease of
58
+ installing and setting up the PhysicsNeMo environment. Please refer to the
59
+ [getting started guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#physicsnemo-with-docker-image-recommended)
60
+ on how to install using Pip or using the container.
physics_mcp/source/LICENSE.txt ADDED
@@ -0,0 +1,201 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ Apache License
2
+ Version 2.0, January 2004
3
+ http://www.apache.org/licenses/
4
+
5
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
6
+
7
+ 1. Definitions.
8
+
9
+ "License" shall mean the terms and conditions for use, reproduction,
10
+ and distribution as defined by Sections 1 through 9 of this document.
11
+
12
+ "Licensor" shall mean the copyright owner or entity authorized by
13
+ the copyright owner that is granting the License.
14
+
15
+ "Legal Entity" shall mean the union of the acting entity and all
16
+ other entities that control, are controlled by, or are under common
17
+ control with that entity. For the purposes of this definition,
18
+ "control" means (i) the power, direct or indirect, to cause the
19
+ direction or management of such entity, whether by contract or
20
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
21
+ outstanding shares, or (iii) beneficial ownership of such entity.
22
+
23
+ "You" (or "Your") shall mean an individual or Legal Entity
24
+ exercising permissions granted by this License.
25
+
26
+ "Source" form shall mean the preferred form for making modifications,
27
+ including but not limited to software source code, documentation
28
+ source, and configuration files.
29
+
30
+ "Object" form shall mean any form resulting from mechanical
31
+ transformation or translation of a Source form, including but
32
+ not limited to compiled object code, generated documentation,
33
+ and conversions to other media types.
34
+
35
+ "Work" shall mean the work of authorship, whether in Source or
36
+ Object form, made available under the License, as indicated by a
37
+ copyright notice that is included in or attached to the work
38
+ (an example is provided in the Appendix below).
39
+
40
+ "Derivative Works" shall mean any work, whether in Source or Object
41
+ form, that is based on (or derived from) the Work and for which the
42
+ editorial revisions, annotations, elaborations, or other modifications
43
+ represent, as a whole, an original work of authorship. For the purposes
44
+ of this License, Derivative Works shall not include works that remain
45
+ separable from, or merely link (or bind by name) to the interfaces of,
46
+ the Work and Derivative Works thereof.
47
+
48
+ "Contribution" shall mean any work of authorship, including
49
+ the original version of the Work and any modifications or additions
50
+ to that Work or Derivative Works thereof, that is intentionally
51
+ submitted to Licensor for inclusion in the Work by the copyright owner
52
+ or by an individual or Legal Entity authorized to submit on behalf of
53
+ the copyright owner. For the purposes of this definition, "submitted"
54
+ means any form of electronic, verbal, or written communication sent
55
+ to the Licensor or its representatives, including but not limited to
56
+ communication on electronic mailing lists, source code control systems,
57
+ and issue tracking systems that are managed by, or on behalf of, the
58
+ Licensor for the purpose of discussing and improving the Work, but
59
+ excluding communication that is conspicuously marked or otherwise
60
+ designated in writing by the copyright owner as "Not a Contribution."
61
+
62
+ "Contributor" shall mean Licensor and any individual or Legal Entity
63
+ on behalf of whom a Contribution has been received by Licensor and
64
+ subsequently incorporated within the Work.
65
+
66
+ 2. Grant of Copyright License. Subject to the terms and conditions of
67
+ this License, each Contributor hereby grants to You a perpetual,
68
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
69
+ copyright license to reproduce, prepare Derivative Works of,
70
+ publicly display, publicly perform, sublicense, and distribute the
71
+ Work and such Derivative Works in Source or Object form.
72
+
73
+ 3. Grant of Patent License. Subject to the terms and conditions of
74
+ this License, each Contributor hereby grants to You a perpetual,
75
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
76
+ (except as stated in this section) patent license to make, have made,
77
+ use, offer to sell, sell, import, and otherwise transfer the Work,
78
+ where such license applies only to those patent claims licensable
79
+ by such Contributor that are necessarily infringed by their
80
+ Contribution(s) alone or by combination of their Contribution(s)
81
+ with the Work to which such Contribution(s) was submitted. If You
82
+ institute patent litigation against any entity (including a
83
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
84
+ or a Contribution incorporated within the Work constitutes direct
85
+ or contributory patent infringement, then any patent licenses
86
+ granted to You under this License for that Work shall terminate
87
+ as of the date such litigation is filed.
88
+
89
+ 4. Redistribution. You may reproduce and distribute copies of the
90
+ Work or Derivative Works thereof in any medium, with or without
91
+ modifications, and in Source or Object form, provided that You
92
+ meet the following conditions:
93
+
94
+ (a) You must give any other recipients of the Work or
95
+ Derivative Works a copy of this License; and
96
+
97
+ (b) You must cause any modified files to carry prominent notices
98
+ stating that You changed the files; and
99
+
100
+ (c) You must retain, in the Source form of any Derivative Works
101
+ that You distribute, all copyright, patent, trademark, and
102
+ attribution notices from the Source form of the Work,
103
+ excluding those notices that do not pertain to any part of
104
+ the Derivative Works; and
105
+
106
+ (d) If the Work includes a "NOTICE" text file as part of its
107
+ distribution, then any Derivative Works that You distribute must
108
+ include a readable copy of the attribution notices contained
109
+ within such NOTICE file, excluding those notices that do not
110
+ pertain to any part of the Derivative Works, in at least one
111
+ of the following places: within a NOTICE text file distributed
112
+ as part of the Derivative Works; within the Source form or
113
+ documentation, if provided along with the Derivative Works; or,
114
+ within a display generated by the Derivative Works, if and
115
+ wherever such third-party notices normally appear. The contents
116
+ of the NOTICE file are for informational purposes only and
117
+ do not modify the License. You may add Your own attribution
118
+ notices within Derivative Works that You distribute, alongside
119
+ or as an addendum to the NOTICE text from the Work, provided
120
+ that such additional attribution notices cannot be construed
121
+ as modifying the License.
122
+
123
+ You may add Your own copyright statement to Your modifications and
124
+ may provide additional or different license terms and conditions
125
+ for use, reproduction, or distribution of Your modifications, or
126
+ for any such Derivative Works as a whole, provided Your use,
127
+ reproduction, and distribution of the Work otherwise complies with
128
+ the conditions stated in this License.
129
+
130
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
131
+ any Contribution intentionally submitted for inclusion in the Work
132
+ by You to the Licensor shall be under the terms and conditions of
133
+ this License, without any additional terms or conditions.
134
+ Notwithstanding the above, nothing herein shall supersede or modify
135
+ the terms of any separate license agreement you may have executed
136
+ with Licensor regarding such Contributions.
137
+
138
+ 6. Trademarks. This License does not grant permission to use the trade
139
+ names, trademarks, service marks, or product names of the Licensor,
140
+ except as required for reasonable and customary use in describing the
141
+ origin of the Work and reproducing the content of the NOTICE file.
142
+
143
+ 7. Disclaimer of Warranty. Unless required by applicable law or
144
+ agreed to in writing, Licensor provides the Work (and each
145
+ Contributor provides its Contributions) on an "AS IS" BASIS,
146
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
147
+ implied, including, without limitation, any warranties or conditions
148
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
149
+ PARTICULAR PURPOSE. You are solely responsible for determining the
150
+ appropriateness of using or redistributing the Work and assume any
151
+ risks associated with Your exercise of permissions under this License.
152
+
153
+ 8. Limitation of Liability. In no event and under no legal theory,
154
+ whether in tort (including negligence), contract, or otherwise,
155
+ unless required by applicable law (such as deliberate and grossly
156
+ negligent acts) or agreed to in writing, shall any Contributor be
157
+ liable to You for damages, including any direct, indirect, special,
158
+ incidental, or consequential damages of any character arising as a
159
+ result of this License or out of the use or inability to use the
160
+ Work (including but not limited to damages for loss of goodwill,
161
+ work stoppage, computer failure or malfunction, or any and all
162
+ other commercial damages or losses), even if such Contributor
163
+ has been advised of the possibility of such damages.
164
+
165
+ 9. Accepting Warranty or Additional Liability. While redistributing
166
+ the Work or Derivative Works thereof, You may choose to offer,
167
+ and charge a fee for, acceptance of support, warranty, indemnity,
168
+ or other liability obligations and/or rights consistent with this
169
+ License. However, in accepting such obligations, You may act only
170
+ on Your own behalf and on Your sole responsibility, not on behalf
171
+ of any other Contributor, and only if You agree to indemnify,
172
+ defend, and hold each Contributor harmless for any liability
173
+ incurred by, or claims asserted against, such Contributor by reason
174
+ of your accepting any such warranty or additional liability.
175
+
176
+ END OF TERMS AND CONDITIONS
177
+
178
+ APPENDIX: How to apply the Apache License to your work.
179
+
180
+ To apply the Apache License to your work, attach the following
181
+ boilerplate notice, with the fields enclosed by brackets "{}"
182
+ replaced with your own identifying information. (Don't include
183
+ the brackets!) The text should be enclosed in the appropriate
184
+ comment syntax for the file format. We also recommend that a
185
+ file or class name and description of purpose be included on the
186
+ same "printed page" as the copyright notice for easier
187
+ identification within third-party archives.
188
+
189
+ Copyright 2022 NVIDIA Corporation
190
+
191
+ Licensed under the Apache License, Version 2.0 (the "License");
192
+ you may not use this file except in compliance with the License.
193
+ You may obtain a copy of the License at
194
+
195
+ http://www.apache.org/licenses/LICENSE-2.0
196
+
197
+ Unless required by applicable law or agreed to in writing, software
198
+ distributed under the License is distributed on an "AS IS" BASIS,
199
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
200
+ See the License for the specific language governing permissions and
201
+ limitations under the License.
physics_mcp/source/README.md ADDED
@@ -0,0 +1,472 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # NVIDIA PhysicsNeMo
2
+
3
+ <!-- markdownlint-disable -->
4
+
5
+ 📝 NVIDIA Modulus has been renamed to NVIDIA PhysicsNeMo
6
+
7
+ [![Project Status: Active - The project has reached a stable, usable state and is being actively developed.](https://www.repostatus.org/badges/latest/active.svg)](https://www.repostatus.org/#active)
8
+ [![GitHub](https://img.shields.io/github/license/NVIDIA/physicsnemo)](https://github.com/NVIDIA/physicsnemo/blob/master/LICENSE.txt)
9
+ [![Code style: black](https://img.shields.io/badge/code%20style-black-000000.svg)](https://github.com/psf/black)
10
+ <!-- markdownlint-enable -->
11
+ [**NVIDIA PhysicsNeMo**](#what-is-physicsnemo)
12
+ | [**Documentation**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html)
13
+ | [**Install Guide**](#installation)
14
+ | [**Getting Started**](#getting-started)
15
+ | [**Contributing Guidelines**](#contributing-to-physicsnemo)
16
+ | [**License**](#license)
17
+
18
+ ## What is PhysicsNeMo?
19
+
20
+ NVIDIA PhysicsNeMo is an open-source deep-learning framework for building, training,
21
+ fine-tuning, and inferring Physics AI models using state-of-the-art SciML methods for
22
+ AI4Science and engineering.
23
+
24
+ PhysicsNeMo provides Python modules to compose scalable and optimized training and
25
+ inference pipelines to explore, develop, validate, and deploy AI models that combine
26
+ physics knowledge with data, enabling real-time predictions.
27
+
28
+ Whether you are exploring the use of neural operators, GNNs, or transformers, or are
29
+ interested in Physics-Informed Neural Networks or a hybrid approach in between, PhysicsNeMo
30
+ provides you with an optimized stack that will enable you to train your models at scale.
31
+
32
+ <!-- markdownlint-disable -->
33
+ <p align="center">
34
+ <img src=https://raw.githubusercontent.com/NVIDIA/physicsnemo/main/docs/img/value_prop/Knowledge_guided_models.gif alt="PhysicsNeMo"/>
35
+ </p>
36
+ <!-- markdownlint-enable -->
37
+
38
+ <!-- toc -->
39
+
40
+ - [More About PhysicsNeMo](#more-about-physicsnemo)
41
+ - [Scalable GPU-Optimized Training Library](#scalable-gpu-optimized-training-library)
42
+ - [A Suite of Physics-Informed ML Models](#a-suite-of-physics-informed-ml-models)
43
+ - [Seamless PyTorch Integration](#seamless-pytorch-integration)
44
+ - [Easy Customization and Extension](#easy-customization-and-extension)
45
+ - [AI4Science Library](#ai4science-library)
46
+ - [Domain-Specific Packages](#domain-specific-packages)
47
+ - [Who is Using and Contributing to PhysicsNeMo](#who-is-using-and-contributing-to-physicsnemo)
48
+ - [Why Use PhysicsNeMo](#why-are-they-using-physicsnemo)
49
+ - [Getting Started](#getting-started)
50
+ - [Resources](#resources)
51
+ - [Installation](#installation)
52
+ - [Contributing](#contributing-to-physicsnemo)
53
+ - [Communication](#communication)
54
+ - [License](#license)
55
+
56
+ <!-- tocstop -->
57
+
58
+ ## More About PhysicsNeMo
59
+
60
+ At a granular level, PhysicsNeMo is developed as modular functionality and therefore
61
+ provides built-in composable modules that are packaged into a few key components:
62
+
63
+ <!-- markdownlint-disable -->
64
+ Component | Description |
65
+ ---- | --- |
66
+ [**physicsnemo.models**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html) | A collection of optimized, customizable, and easy-to-use families of model architectures such as Neural Operators, Graph Neural Networks, Diffusion models, Transformer models, and many more|
67
+ [**physicsnemo.datapipes**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.datapipes.html) | Optimized and scalable built-in data pipelines fine-tuned to handle engineering and scientific data structures like point clouds, meshes, etc.|
68
+ [**physicsnemo.distributed**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html) | A distributed computing sub-module built on top of `torch.distributed` to enable parallel training with just a few steps|
69
+ [**physicsnemo.curator**](https://github.com/NVIDIA/physicsnemo-curator) | A sub-module to streamline and accelerate the process of data curation for engineering datasets|
70
+ [**physicsnemo.sym.geometry**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/features/csg_and_tessellated_module.html) | A sub-module to handle geometry for DL training using Constructive Solid Geometry modeling and CAD files in STL format|
71
+ [**physicsnemo.sym.eq**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/features/nodes.html) | A sub-module to use PDEs in your DL training with several implementations of commonly observed equations and easy ways for customization|
72
+ <!-- markdownlint-enable -->
73
+
74
+ For a complete list, refer to the PhysicsNeMo API documentation for
75
+ [PhysicsNeMo](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html).
76
+
77
+ ## AI4Science Library
78
+
79
+ Usually, PhysicsNeMo is used either as:
80
+
81
+ - A complementary tool to PyTorch when exploring AI for SciML and AI4Science applications.
82
+ - A deep learning research platform that provides scale and optimal performance on
83
+ NVIDIA GPUs.
84
+
85
+ ### Domain-Specific Packages
86
+
87
+ The following are packages dedicated to domain experts of specific communities, catering
88
+ to their unique exploration needs:
89
+
90
+ - [PhysicsNeMo CFD](https://github.com/NVIDIA/physicsnemo-cfd): Inference sub-module of PhysicsNeMo
91
+ to enable CFD domain experts to explore, experiment, and validate using pretrained
92
+ AI models for CFD use cases.
93
+ - [PhysicsNeMo Curator](https://github.com/NVIDIA/physicsnemo-curator): Sub-module
94
+ of PhysicsNeMo to streamline and accelerate the process of data curation for engineering
95
+ datasets.
96
+ - [Earth-2 Studio](https://github.com/NVIDIA/earth2studio): Inference sub-module of PhysicsNeMo
97
+ to enable climate researchers and scientists to explore and experiment with pretrained
98
+ AI models for weather and climate.
99
+
100
+ ### Scalable GPU-Optimized Training Library
101
+
102
+ PhysicsNeMo provides a highly optimized and scalable training library for maximizing the
103
+ power of NVIDIA GPUs.
104
+ [Distributed computing](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html)
105
+ utilities allow for efficient scaling from a single GPU to multi-node GPU clusters with
106
+ a few lines of code, ensuring that large-scale
107
+ physics-informed machine learning (ML) models can be trained quickly and effectively.
108
+ The framework includes support for advanced
109
+ [optimization utilities](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.utils.html#module-physicsnemo.utils.capture),
110
+ [tailor-made datapipes](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.datapipes.html),
111
+ and [validation utilities](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/eq)
112
+ to enhance end-to-end training speed.
113
+
114
+ ### A Suite of Physics-Informed ML Models
115
+
116
+ PhysicsNeMo offers a library of state-of-the-art models specifically designed
117
+ for Physics-ML applications. Users can build any model architecture by using the underlying
118
+ PyTorch layers and combining them with curated PhysicsNeMo layers.
119
+
120
+ The [Model Zoo](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html#model-zoo)
121
+ includes optimized implementations of families of model architectures such as
122
+ Neural Operators:
123
+
124
+ - [Fourier Neural Operators (FNOs)](physicsnemo/models/fno)
125
+ - [DeepONet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/neural_operators/deeponet.html)
126
+ - [DoMINO](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/external_aerodynamics/domino/readme.html)
127
+ - [Graph Neural Networks (GNNs)](physicsnemo/models/gnn_layers)
128
+ - [MeshGraphNet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/vortex_shedding_mgn/readme.html)
129
+ - [MeshGraphNet for Lagrangian](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/lagrangian_mgn/readme.html)
130
+ - [XAeroNet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/external_aerodynamics/xaeronet/readme.html)
131
+ - [Diffusion Models](physicsnemo/models/diffusion)
132
+ - [Correction Diffusion Model](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/generative/corrdiff/readme.html)
133
+ - [DDPM](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/generative/diffusion/readme.html)
134
+ - [PhysicsNeMo GraphCast](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/weather/graphcast/readme.html)
135
+ - [Transsolver](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/darcy_transolver)
136
+ - [RNNs](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models)
137
+ - [SwinVRNN](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models/swinvrnn)
138
+ - [Physics-Informed Neural Networks (PINNs)](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/foundational/1d_wave_equation.html)
139
+
140
+ And many others.
141
+
142
+ These models are optimized for various physics domains, such as computational fluid
143
+ dynamics, structural mechanics, and electromagnetics. Users can download, customize, and
144
+ build upon these models to suit their specific needs, significantly reducing the time
145
+ required to develop high-fidelity simulations.
146
+
147
+ ### Seamless PyTorch Integration
148
+
149
+ PhysicsNeMo is built on top of PyTorch, providing a familiar and user-friendly experience
150
+ for those already proficient with PyTorch.
151
+ This includes a simple Python interface and modular design, making it easy to use
152
+ PhysicsNeMo with existing PyTorch workflows.
153
+ Users can leverage the extensive PyTorch ecosystem, including its libraries and tools,
154
+ while benefiting from PhysicsNeMo's specialized capabilities for physics-ML. This seamless
155
+ integration ensures users can quickly adopt PhysicsNeMo without a steep learning curve.
156
+
157
+ For more information, refer to [Converting PyTorch Models to PhysicsNeMo Models](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html#converting-pytorch-models-to-physicsnemo-models).
158
+
159
+ ### Easy Customization and Extension
160
+
161
+ PhysicsNeMo is designed to be highly extensible, allowing users to add new functionality
162
+ with minimal effort. The framework provides Pythonic APIs for
163
+ defining new physics models, geometries, and constraints, making it easy to extend its
164
+ capabilities to new use cases.
165
+ The adaptability of PhysicsNeMo is further enhanced by key features such as
166
+ [ONNX support](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.deploy.html)
167
+ for flexible model deployment,
168
+ robust [logging utilities](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.launch.logging.html)
169
+ for streamlined error handling,
170
+ and efficient
171
+ [checkpointing](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.launch.utils.html#module-physicsnemo.launch.utils.checkpoint)
172
+ to simplify model loading and saving.
173
+
174
+ This extensibility ensures that PhysicsNeMo can adapt to the evolving needs of researchers
175
+ and engineers, facilitating the development of innovative solutions in the field of physics-ML.
176
+
177
+ Detailed information on features and capabilities can be found in the [PhysicsNeMo documentation](https://docs.nvidia.com/physicsnemo/index.html#core).
178
+
179
+ [Reference samples](examples/README.md) cover a broad spectrum of physics-constrained
180
+ and data-driven
181
+ workflows to suit the diversity of use cases in the science and engineering disciplines.
182
+
183
+ > [!TIP]
184
+ > Have questions about how PhysicsNeMo can assist you? Try our [Experimental] chatbot,
185
+ > [PhysicsNeMo Guide](https://chatgpt.com/g/g-PXrBv20SC-modulus-guide), for answers.
186
+
187
+ ### Hello World
188
+
189
+ You can start using PhysicsNeMo in your PyTorch code as simply as shown here:
190
+
191
+ ```python
192
+ >>> import torch
193
+ >>> from physicsnemo.models.mlp.fully_connected import FullyConnected
194
+ >>> model = FullyConnected(in_features=32, out_features=64)
195
+ >>> input = torch.randn(128, 32)
196
+ >>> output = model(input)
197
+ >>> output.shape
198
+ torch.Size([128, 64])
199
+ ```
200
+
201
+ To use the distributed module, you can do the following (example for
202
+ distributed data parallel training; for a more in-depth tutorial, refer to
203
+ [PhysicsNeMo Distributed](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html#)):
204
+
205
+ ```python
206
+ import torch
207
+ from torch.nn.parallel import DistributedDataParallel
208
+ from physicsnemo.distributed import DistributedManager
209
+ from physicsnemo.models.mlp.fully_connected import FullyConnected
210
+
211
+ def main():
212
+ DistributedManager.initialize()
213
+ dist = DistributedManager()
214
+
215
+ arch = FullyConnected(in_features=32, out_features=64).to(dist.device)
216
+
217
+ if dist.distributed:
218
+ ddps = torch.cuda.Stream()
219
+ with torch.cuda.stream(ddps):
220
+ arch = DistributedDataParallel(
221
+ arch,
222
+ device_ids=[dist.local_rank],
223
+ output_device=dist.device,
224
+ broadcast_buffers=dist.broadcast_buffers,
225
+ find_unused_parameters=dist.find_unused_parameters,
226
+ )
227
+ torch.cuda.current_stream().wait_stream(ddps)
228
+
229
+ # Set up the optimizer
230
+ optimizer = torch.optim.Adam(
231
+ arch.parameters(),
232
+ lr=0.001,
233
+ )
234
+
235
+ def training_step(invar, target):
236
+ pred = arch(invar)
237
+ loss = torch.sum(torch.pow(pred - target, 2))
238
+ optimizer.zero_grad()  # clear gradients accumulated from the previous step
+ loss.backward()
239
+ optimizer.step()
240
+ return loss
241
+
242
+ # Sample training loop
243
+ for i in range(20):
244
+ # Random inputs and targets for simplicity
245
+ input = torch.randn(128, 32, device=dist.device)
246
+ target = torch.randn(128, 64, device=dist.device)
247
+
248
+ # Training step
249
+ loss = training_step(input, target)
250
+
251
+ if __name__ == "__main__":
252
+ main()
253
+ ```
254
+
255
+ To use the PDE module, you can do the following:
256
+
257
+ ```python
258
+ >>> from physicsnemo.sym.eq.pdes.navier_stokes import NavierStokes
259
+ >>> ns = NavierStokes(nu=0.01, rho=1, dim=2)
260
+ >>> ns.pprint()
261
+ continuity: u__x + v__y
262
+ momentum_x: u*u__x + v*u__y + p__x + u__t - 0.01*u__x__x - 0.01*u__y__y
263
+ momentum_y: u*v__x + v*v__y + p__y + v__t - 0.01*v__x__x - 0.01*v__y__y
264
+ ```
265
+
266
+ ## Who is Using and Contributing to PhysicsNeMo
267
+
268
+ PhysicsNeMo is an open-source project and gets contributions from researchers in
269
+ the SciML and AI4Science fields. While the PhysicsNeMo team works on optimizing the
270
+ underlying software stack, the community collaborates and contributes model architectures,
271
+ datasets, and reference applications so we can innovate in the pursuit of
272
+ developing generalizable model architectures and algorithms.
273
+
274
+ Some recent examples of community contributors are the [HP Labs 3D Printing team](https://developer.nvidia.com/blog/spotlight-hp-3d-printing-and-nvidia-physicsnemo-collaborate-on-open-source-manufacturing-digital-twin/),
275
+ [Stanford Cardiovascular research team](https://developer.nvidia.com/blog/enabling-greater-patient-specific-cardiovascular-care-with-ai-surrogates/),
276
+ [UIUC team](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/mhd_pino),
277
+ [CMU team](https://github.com/NVIDIA/physicsnemo/tree/main/examples/generative/diffusion),
278
+ etc.
279
+
280
+ Recent examples of research teams using PhysicsNeMo are the
281
+ [ORNL team](https://arxiv.org/abs/2404.05768),
282
+ [TU Munich CFD team](https://www.nvidia.com/en-us/on-demand/session/gtc24-s62237/), etc.
283
+
284
+ Please navigate to this page for a complete list of research work leveraging PhysicsNeMo.
285
+ For a list of enterprises using PhysicsNeMo, refer to the [PhysicsNeMo Webpage](https://developer.nvidia.com/physicsnemo).
286
+
287
+ Using PhysicsNeMo and interested in showcasing your work on
288
+ [NVIDIA Blogs](https://developer.nvidia.com/blog/category/simulation-modeling-design/)?
289
+ Fill out this [proposal form](https://forms.gle/XsBdWp3ji67yZAUF7) and we will get back
290
+ to you!
291
+
292
+ ## Why Are They Using PhysicsNeMo
293
+
294
+ Here are some of the key benefits of PhysicsNeMo for SciML model development:
295
+
296
+ <!-- markdownlint-disable -->
297
+ <img src="docs/img/value_prop/benchmarking.svg" width="100"> | <img src="docs/img/value_prop/recipe.svg" width="100"> | <img src="docs/img/value_prop/performance.svg" width="100">
298
+ ---|---|---|
299
+ |SciML Benchmarking and Validation|Ease of Using Generalized SciML Recipes with Heterogeneous Datasets |Out-of-the-Box Performance and Scalability
300
+ |PhysicsNeMo enables researchers to benchmark their AI models against proven architectures for standard benchmark problems with detailed domain-specific validation criteria.|PhysicsNeMo enables researchers to pick from state-of-the-art SciML architectures and use built-in data pipelines for their use case.| PhysicsNeMo provides out-of-the-box performant training pipelines, including optimized ETL pipelines for heterogeneous engineering and scientific datasets and out-of-the-box scaling across multi-GPU and multi-node GPUs.
301
+ <!-- markdownlint-enable -->
302
+
303
+ See what your peer SciML researchers are saying about PhysicsNeMo (coming soon).
304
+
305
+ ## Getting Started
306
+
307
+ The following resources will help you learn how to use PhysicsNeMo. The best
308
+ way is to start with a reference sample and then update it for your own use case.
309
+
310
+ - [Using PhysicsNeMo with your PyTorch model](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/tutorials/simple_training_example.html#using-custom-models-in-physicsnemo)
311
+ - [Using PhysicsNeMo built-in models](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/tutorials/simple_training_example.html#using-built-in-models)
312
+ - [Getting Started Guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html)
313
+ - [Reference Samples](https://github.com/NVIDIA/physicsnemo/blob/main/examples/README.md)
314
+ - [User Guide Documentation](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html)
315
+
316
+ ## Resources
317
+
318
+ - [Getting Started Webinar](https://www.nvidia.com/en-us/on-demand/session/gtc24-dlit61460/?playlistId=playList-bd07f4dc-1397-4783-a959-65cec79aa985)
319
+ - [AI4Science PhysicsNeMo Bootcamp](https://github.com/openhackathons-org/End-to-End-AI-for-Science)
320
+ - [PhysicsNeMo Pretrained Models](https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=PhysicsNeMo&page=&pageSize=)
321
+ - [PhysicsNeMo Datasets and Supplementary Materials](https://catalog.ngc.nvidia.com/resources?filters=&orderBy=scoreDESC&query=PhysicsNeMo&page=&pageSize=)
322
+ - [Self-Paced PhysicsNeMo DLI Training](https://learn.nvidia.com/courses/course-detail?course_id=course-v1:DLI+S-OV-04+V1)
323
+ - [Deep Learning for Science and Engineering Lecture Series with PhysicsNeMo](https://www.nvidia.com/en-us/on-demand/deep-learning-for-science-and-engineering/)
324
+ - [PhysicsNeMo: Purpose and Usage](https://www.nvidia.com/en-us/on-demand/session/dliteachingkit-setk5002/)
325
+ - [Video Tutorials](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=physicsnemo&sort=relevance&sortDir=desc)
326
+
327
+ ## Installation
328
+
329
+ The following instructions help you install the base PhysicsNeMo modules to get started.
330
+ There are additional optional dependencies for specific models that are listed under
331
+ [optional dependencies](#optional-dependencies).
332
+ The training recipes are not packaged into the pip wheels or the container to keep the
333
+ footprint low. We recommend users clone the appropriate training recipes and use them
334
+ as a starting point. These training recipes may require additional example-specific dependencies,
335
+ as indicated through their associated `requirements.txt` file.
336
+
337
+ ### PyPI
338
+
339
+ The recommended method for installing the latest version of PhysicsNeMo is using PyPI:
340
+
341
+ ```Bash
342
+ pip install nvidia-physicsnemo
343
+ ```
344
+
345
+ The installation can be verified by running the [Hello World](#hello-world) example.
346
+
347
+ #### Optional Dependencies
348
+
349
+ PhysicsNeMo has many optional dependencies that are used in specific components.
350
+ When using pip, all dependencies used in PhysicsNeMo can be installed with
351
+ `pip install nvidia-physicsnemo[all]`. If you are developing PhysicsNeMo, developer dependencies
352
+ can be installed using `pip install nvidia-physicsnemo[dev]`. Otherwise, additional dependencies
353
+ can be installed on a case-by-case basis. Detailed information on installing the
354
+ optional dependencies can be found in the
355
+ [Getting Started Guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html).
356
+
357
+ ### NVCR Container
358
+
359
+ The recommended PhysicsNeMo Docker image can be pulled from the
360
+ [NVIDIA Container Registry](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo)
361
+ (refer to the NGC registry for the latest tag):
362
+
363
+ ```Bash
364
+ docker pull nvcr.io/nvidia/physicsnemo/physicsnemo:25.06
365
+ ```
366
+
367
+ Inside the container, you can clone the PhysicsNeMo git repositories and get
368
+ started with the examples. The command below shows the instructions to launch
369
+ the PhysicsNeMo container and run examples from this repo:
370
+
371
+ ```bash
372
+ docker run --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --runtime nvidia \
373
+ --rm -it nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 bash
374
+ git clone https://github.com/NVIDIA/physicsnemo.git
375
+ cd physicsnemo/examples/cfd/darcy_fno/
376
+ pip install warp-lang # install NVIDIA Warp to run the Darcy example
377
+ python train_fno_darcy.py
378
+ ```
379
+
380
+ ## From Source
381
+
382
+ ### Package
383
+
384
+ For a local build of the PhysicsNeMo Python package from source, use:
385
+
386
+ ```Bash
387
+ git clone git@github.com:NVIDIA/physicsnemo.git && cd physicsnemo
388
+
389
+ pip install --upgrade pip
390
+ pip install .
391
+ ```
392
+
393
+ ### Source Container
394
+
395
+ To build the PhysicsNeMo Docker image:
396
+
397
+ ```bash
398
+ docker build -t physicsnemo:deploy \
399
+ --build-arg TARGETPLATFORM=linux/amd64 --target deploy -f Dockerfile .
400
+ ```
401
+
402
+ Alternatively, you can run `make container-deploy`.
403
+
404
+ To build the CI image:
405
+
406
+ ```bash
407
+ docker build -t physicsnemo:ci \
408
+ --build-arg TARGETPLATFORM=linux/amd64 --target ci -f Dockerfile .
409
+ ```
410
+
411
+ Alternatively, you can run `make container-ci`.
412
+
413
+ Currently, only `linux/amd64` and `linux/arm64` platforms are supported. If using
414
+ `linux/arm64`, some dependencies like `warp-lang` might not install correctly.
415
+
416
+ ## PhysicsNeMo Migration Guide
417
+
418
+ NVIDIA Modulus has been renamed to NVIDIA PhysicsNeMo. For migration:
419
+
420
+ - Use `pip install nvidia-physicsnemo` rather than `pip install nvidia-modulus`
421
+ for PyPI wheels.
422
+ - Use `nvcr.io/nvidia/physicsnemo/physicsnemo:<tag>` rather than
423
+ `nvcr.io/nvidia/modulus/modulus:<tag>` for Docker containers.
424
+ - Replace `nvidia-modulus` with `nvidia-physicsnemo` in your pip requirements
425
+ files (`requirements.txt`, `setup.py`, `setup.cfg`, `pyproject.toml`, etc.).
426
+ - In your code, change the import statements from `import modulus` to
427
+ `import physicsnemo`.
428
+
429
+ The old PyPI registry and the NGC container registry will be deprecated soon
430
+ and will not receive any bug fixes/updates. The old checkpoints will remain
431
+ compatible with these updates.
432
+
433
+ More details to follow soon.
434
+
435
+ ## DGL to PyTorch Geometric Migration Guide
436
+
437
+ PhysicsNeMo supports a wide range of Graph Neural Networks (GNNs),
438
+ including MeshGraphNet and others.
439
+ Currently, PhysicsNeMo uses the DGL library as its GNN backend,
440
+ with plans to completely transition to PyTorch Geometric (PyG) in a future release.
441
+ For more details, please refer to the [DGL-to-PyG migration guide](https://github.com/NVIDIA/physicsnemo/blob/main/examples/dgl_to_pyg_migration.md).
442
+
443
+ ## Contributing to PhysicsNeMo
444
+
445
+ PhysicsNeMo is an open-source collaboration, and its success is rooted in community
446
+ contributions to further the field of Physics-ML. Thank you for contributing to the
447
+ project so others can build on top of your contributions.
448
+
449
+ For guidance on contributing to PhysicsNeMo, please refer to the
450
+ [contributing guidelines](CONTRIBUTING.md).
451
+
452
+ ## Cite PhysicsNeMo
453
+
454
+ If PhysicsNeMo helped your research and you would like to cite it, please refer to the [guidelines](https://github.com/NVIDIA/physicsnemo/blob/main/CITATION.cff).
455
+
456
+ ## Communication
457
+
458
+ - GitHub Discussions: Discuss new architectures, implementations, Physics-ML research, etc.
459
+ - GitHub Issues: Bug reports, feature requests, install issues, etc.
460
+ - PhysicsNeMo Forum: The [PhysicsNeMo Forum](https://forums.developer.nvidia.com/t/welcome-to-the-physicsnemo-ml-model-framework-forum/178556)
461
+ hosts an audience of new to moderate-level users and developers for general chat, online
462
+ discussions, collaboration, etc.
463
+
464
+ ## Feedback
465
+
466
+ Want to suggest some improvements to PhysicsNeMo? Use our [feedback form](https://docs.google.com/forms/d/e/1FAIpQLSfX4zZ0Lp7MMxzi3xqvzX4IQDdWbkNh5H_a_clzIhclE2oSBQ/viewform?usp=sf_link).
467
+
468
+ ## License
469
+
470
+ PhysicsNeMo is provided under the Apache License 2.0. Please see [LICENSE.txt](./LICENSE.txt)
471
+ for the full license text. Enterprise SLA, support, and preview access are available
472
+ under NVAIE.
physics_mcp/source/SECURITY.md ADDED
@@ -0,0 +1,34 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Security
2
+
3
+ NVIDIA is dedicated to the security and trust of our software products and
4
+ services, including all source code repositories managed through our organization.
5
+
6
+ If you need to report a security issue, please use the appropriate contact points
7
+ outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.**
8
+
9
+ ## Reporting Potential Security Vulnerability in an NVIDIA Product
10
+
11
+ To report a potential security vulnerability in any NVIDIA product:
12
+
13
+ - Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
14
+ - E-Mail: `psirt@nvidia.com`
15
+ - We encourage you to use the following PGP key for secure email communication:
16
+ [NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
17
+ - Please include the following information:
18
+ - Product/Driver name and version/branch that contains the vulnerability
19
+ - Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
20
+ - Instructions to reproduce the vulnerability
21
+ - Proof-of-concept or exploit code
22
+ - Potential impact of the vulnerability, including how an attacker could
23
+ exploit the vulnerability
24
+
25
+ While NVIDIA currently does not have a bug bounty program, we do offer
26
+ acknowledgement when an externally reported security issue is addressed under our
27
+ coordinated vulnerability disclosure policy. Please visit our
28
+ [Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/)
29
+ policies page for more information.
30
+
31
+ ## NVIDIA Product Security
32
+
33
+ For all security-related concerns, please visit NVIDIA's Product Security portal
34
+ at `https://www.nvidia.com/en-us/security`
physics_mcp/source/__init__.py ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ # -*- coding: utf-8 -*-
2
+ """
3
+ physicsnemo Project Package Initialization File
4
+ """
physics_mcp/source/greptile.json ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "comment": "",
3
+ "fixWithAI": false,
4
+ "commentTypes": [
5
+ "logic",
6
+ "syntax",
7
+ "style"
8
+ ],
9
+ "instructions": "",
10
+ "excludeAuthors": [
11
+ "dependabot[bot]",
12
+ "renovate[bot]"
13
+ ],
14
+ "ignorePatterns": "greptile.json\n",
15
+ "summarySection": {
16
+ "included": true,
17
+ "collapsible": false,
18
+ "defaultOpen": false
19
+ },
20
+ "triggerOnUpdates": false,
21
+ "updateSummaryOnly": false,
22
+ "issuesTableSection": {
23
+ "included": true,
24
+ "collapsible": false,
25
+ "defaultOpen": false
26
+ },
27
+ "confidenceScoreSection": {
28
+ "included": false,
29
+ "collapsible": false,
30
+ "defaultOpen": false
31
+ },
32
+ "sequenceDiagramSection": {
33
+ "included": false,
34
+ "collapsible": false,
35
+ "defaultOpen": false
36
+ },
37
+ "shouldUpdateDescription": false,
38
+ "customContext": {
39
+ "other": [
40
+ {
41
+ "scope": [],
42
+ "content": ""
43
+ }
44
+ ],
45
+ "rules": [
46
+ {
47
+ "scope": [],
48
+ "rule": ""
49
+ }
50
+ ],
51
+ "files": [
52
+ {
53
+ "scope": [],
54
+ "path": "",
55
+ "description": ""
56
+ }
57
+ ]
58
+ }
59
+ }
physics_mcp/source/physicsnemo/__init__.py ADDED
@@ -0,0 +1,22 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .datapipes.datapipe import Datapipe
18
+ from .datapipes.meta import DatapipeMetaData
19
+ from .models.meta import ModelMetaData
20
+ from .models.module import Module
21
+
22
+ __version__ = "1.3.0a0"
physics_mcp/source/physicsnemo/active_learning/README.md ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Active Learning Module
2
+
3
+ The `physicsnemo.active_learning` namespace is used for defining the "scaffolding"
4
+ that can be used to construct automated, end-to-end active learning workflows.
5
+ For areas of science that are difficult to source ground-truths to train on
6
+ (of which there are many), an active learning curriculum attempts to train a
7
+ model with improved data efficiency; better generalization performance but requiring
8
+ fewer training samples.
9
+
10
+ Generally, an active learning workflow can be decomposed into three "phases"
11
+ that are - in the simplest case - run sequentially:
12
+
13
+ - **Training/fine-tuning**: A "learner" or surrogate model is initially trained
14
+ on available data, and in subsequent active learning iterations, is fine-tuned
15
+ with the new data appended on the original dataset.
16
+ - **Querying**: One or more strategies that encode some heuristics for what
17
+ new data is most informative for the learner. Examples of this include
18
+ uncertainty-based methods, which may screen a pool of unlabeled data for
19
+ those the model is least confident with.
20
+ - **Labeling**: A method of obtaining ground truth (labels) for new data
21
+ points, pipelined from the querying stage. This may entail running an
22
+ expensive solver, or acquiring experimental data.
23
+
24
+ The three phases are repeated until the learner converges. Because "convergence"
25
+ may not be easily defined, we define an additional phase which we call
26
+ **metrology**: this represents a phase most similar to querying, but allows
27
+ a user to define some set of criteria to monitor over the course of active
28
+ learning *beyond* simple validation metrics to ensure the model can be used
29
+ with confidence as surrogates (e.g. within a simulation loop).
30
+
31
+ ## How to use this module
32
+
33
+ With the context above in mind, inspecting the `driver` module will give you
34
+ a sense for how the end-to-end workflow functions; the `Driver` class acts
35
+ as an orchestrator for all the phases of active learning we described above.
36
+
37
+ From there, you should realize that `Driver` is written in a highly abstract
38
+ way: we need concrete *strategies* that implement querying, labeling, and metrology
39
+ concepts. The `protocols` module provides the scaffolding to do so - we implement
40
+ various components as `typing.Protocol` which are used for structural sub-typing:
41
+ they can be thought of as abstract classes that define an expected interface
42
+ in a function or class from which you can define your own classes by either
43
+ inheriting from them, or defining your own class that implements the expected
44
+ methods and attributes.
45
+
46
+ In order to perform the training portion of active learning, we provide a
47
+ minimal yet functional `DefaultTrainingLoop` inside the `loop` module. This
48
+ loop simply requires a `protocols.TrainingProtocol` to be passed, which is
49
+ a function that defines the logic for computing the loss per batch/training
50
+ step.
51
+
52
+ ## Configuring workflows
53
+
54
+ The `config` module defines some simple `dataclass`es that can be used
55
+ to configure the behavior of various parts of active learning, e.g. how
56
+ training is conducted, etc. Because `Driver` is designed to be checkpointable,
57
+ with the exception of a few parts such as datasets, everything should be
58
+ JSON-serializable.
59
+
60
+ ## Restarting workflows
61
+
62
+ For classes and functions that are created at runtime, checkpointing requires
63
+ that these components can be recreated when restarting from a checkpoint. To
64
+ that end, the `_registry` module provides a user-friendly way to instantiate
65
+ objects: user-defined strategy classes can be added to the registry to enable
66
+ their creation in checkpoint restarts.
physics_mcp/source/physicsnemo/active_learning/__init__.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from physicsnemo.active_learning._registry import registry
18
+ from physicsnemo.active_learning.config import (
19
+ DriverConfig,
20
+ OptimizerConfig,
21
+ StrategiesConfig,
22
+ TrainingConfig,
23
+ )
24
+ from physicsnemo.active_learning.driver import Driver
25
+ from physicsnemo.active_learning.loop import DefaultTrainingLoop
26
+
27
+ __all__ = [
28
+ "registry",
29
+ "Driver",
30
+ "DefaultTrainingLoop",
31
+ "DriverConfig",
32
+ "OptimizerConfig",
33
+ "StrategiesConfig",
34
+ "TrainingConfig",
35
+ ]
physics_mcp/source/physicsnemo/active_learning/_registry.py ADDED
@@ -0,0 +1,332 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+
19
+ import importlib
20
+ import inspect
21
+ from typing import Any, Callable
22
+ from warnings import warn
23
+
24
+ from physicsnemo.active_learning.protocols import ActiveLearningProtocol
25
+
26
+ __all__ = ["registry"]
27
+
28
+
29
+ class ActiveLearningRegistry:
30
+ """
31
+ Registry for active learning protocols.
32
+
33
+ This class provides a centralized registry for user-defined active learning
34
+ protocols that implement the `ActiveLearningProtocol`. It enables string-based
35
+ lookups for checkpointing and provides argument validation when constructing
36
+ protocol instances.
37
+
38
+ The registry supports two primary modes of interaction:
39
+ 1. Registration via decorator: `@registry.register("my_strategy")`
40
+ 2. Construction with validation: `registry.construct("my_strategy", **kwargs)`
41
+
42
+ Attributes
43
+ ----------
44
+ _registry : dict[str, type[ActiveLearningProtocol]]
45
+ Internal dictionary mapping protocol names to their class types.
46
+
47
+ Methods
48
+ -------
49
+ register(cls_name: str) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]
50
+ Decorator to register a protocol class with a given name.
51
+ construct(cls_name: str, **kwargs) -> ActiveLearningProtocol
52
+ Construct an instance of a registered protocol with argument validation.
53
+ is_registered(cls_name: str) -> bool
54
+ Check if a protocol name is registered.
55
+
56
+ Properties
57
+ ----------
58
+ registered_names : list[str]
59
+ A list of all registered protocol names, sorted alphabetically.
60
+
61
+ Examples
62
+ --------
63
+ Register a custom strategy:
64
+
65
+ >>> from physicsnemo.active_learning._registry import registry
66
+ >>> @registry.register("my_custom_strategy")
67
+ ... class MyCustomStrategy:
68
+ ... def __init__(self, param1: int, param2: str):
69
+ ... self.param1 = param1
70
+ ... self.param2 = param2
71
+
72
+ Construct an instance with validation:
73
+
74
+ >>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test")
75
+ """
76
+
77
+ def __init__(self) -> None:
78
+ """Initialize an empty registry."""
79
+ self._registry: dict[str, type[ActiveLearningProtocol]] = {}
80
+
81
+ def register(
82
+ self, cls_name: str
83
+ ) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]:
84
+ """
85
+ Decorator to register an active learning protocol class.
86
+
87
+ This decorator registers a class implementing the `ActiveLearningProtocol`
88
+ under the given name, allowing it to be retrieved and constructed later
89
+ using the `construct` method.
90
+
91
+ Parameters
92
+ ----------
93
+ cls_name : str
94
+ The name to register the protocol under. This will be used as the
95
+ key for later retrieval.
96
+
97
+ Returns
98
+ -------
99
+ Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]
100
+ A decorator function that registers the class and returns it unchanged.
101
+
102
+ Raises
103
+ ------
104
+ ValueError
105
+ If a protocol with the same name is already registered.
106
+
107
+ Examples
108
+ --------
109
+ >>> @registry.register("my_new_strategy")
110
+ ... class MyStrategy:
111
+ ... def __init__(self, param: int):
112
+ ... self.param = param
113
+ """
114
+
115
+ def decorator(
116
+ cls: type[ActiveLearningProtocol],
117
+ ) -> type[ActiveLearningProtocol]:
118
+ """
119
+ Method for decorating a class to registry it with the registry.
120
+ """
121
+ if cls_name in self._registry:
122
+ raise ValueError(
123
+ f"Protocol '{cls_name}' is already registered. "
124
+ f"Existing class: {self._registry[cls_name].__name__}"
125
+ )
126
+ self._registry[cls_name] = cls
127
+ return cls
128
+
129
+ return decorator
130
+
131
+ def construct(
132
+ self, cls_name: str, module_path: str | None = None, **kwargs: Any
133
+ ) -> ActiveLearningProtocol:
134
+ """
135
+ Construct an instance of a registered protocol with argument validation.
136
+
137
+ This method retrieves a registered protocol class by name, validates that
138
+ the provided keyword arguments match the class's constructor signature,
139
+ and returns a new instance of the class.
140
+
141
+ Parameters
142
+ ----------
143
+ cls_name : str
144
+ The name of the registered protocol to construct.
145
+ module_path: str | None
146
+ The path to the module to get the class from.
147
+ **kwargs : Any
148
+ Keyword arguments to pass to the protocol's constructor.
149
+
150
+ Returns
151
+ -------
152
+ ActiveLearningProtocol
153
+ A new instance of the requested protocol class.
154
+
155
+ Raises
156
+ ------
157
+ KeyError
158
+ If the protocol name is not registered.
159
+ TypeError
160
+ If the provided keyword arguments do not match the constructor signature.
161
+ This includes missing required parameters or unexpected parameters.
162
+
163
+ Examples
164
+ --------
165
+ >>> from physicsnemo.active_learning._registry import registry
166
+ >>> @registry.register("my_latest_strategy")
167
+ ... class MyStrategy:
168
+ ... def __init__(self, param: int):
169
+ ... self.param = param
170
+ >>> strategy = registry.construct("my_latest_strategy", param=42)
171
+ """
172
+ cls = self.get_class(cls_name, module_path)
173
+
174
+ # Validate arguments against the class signature
175
+ try:
176
+ sig = inspect.signature(cls.__init__)
177
+ except (ValueError, TypeError) as e:
178
+ raise TypeError(
179
+ f"Could not inspect signature of {cls.__name__}.__init__: {e}"
180
+ )
181
+
182
+ # Get parameters, excluding 'self'
183
+ params = {
184
+ name: param for name, param in sig.parameters.items() if name != "self"
185
+ }
186
+
187
+ # Check if the signature accepts **kwargs
188
+ has_var_keyword = any(
189
+ p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
190
+ )
191
+
192
+ # Check for missing required parameters
193
+ missing = []
194
+ for name, param in params.items():
195
+ if (
196
+ param.kind
197
+ not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
198
+ and param.default is inspect.Parameter.empty
199
+ and name not in kwargs
200
+ ):
201
+ missing.append(name)
202
+
203
+ if missing:
204
+ raise TypeError(
205
+ f"Missing required arguments for {cls.__name__}: {', '.join(missing)}"
206
+ )
207
+
208
+ # Check for unexpected parameters (unless **kwargs is present)
209
+ if not has_var_keyword:
210
+ param_names = {
211
+ name
212
+ for name, param in params.items()
213
+ if param.kind
214
+ not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
215
+ }
216
+ unexpected = [name for name in kwargs if name not in param_names]
217
+
218
+ if unexpected:
219
+ warn(
220
+ f"Unexpected arguments for {cls.__name__}: {', '.join(unexpected)}. "
221
+ f"Valid parameters: {', '.join(sorted(param_names))}"
222
+ )
223
+ return cls(**kwargs)
224
+
225
+ def __getitem__(self, cls_name: str) -> type[ActiveLearningProtocol]:
226
+ """
227
+ Retrieve a registered protocol class by name using dict-like access.
228
+
229
+ This method allows accessing registered protocol classes using square
230
+ bracket notation, e.g., `registry['my_strategy']`.
231
+
232
+ Parameters
233
+ ----------
234
+ cls_name : str
235
+ The name of the registered protocol to retrieve.
236
+
237
+ Returns
238
+ -------
239
+ type[ActiveLearningProtocol]
240
+ The class type of the registered protocol.
241
+
242
+ Raises
243
+ ------
244
+ KeyError
245
+ If the protocol name is not registered.
246
+
247
+ Examples
248
+ --------
249
+ >>> from physicsnemo.active_learning._registry import registry
250
+ >>> @registry.register("my_strategy")
251
+ ... class MyStrategy:
252
+ ... def __init__(self, param: int):
253
+ ... self.param = param
254
+ >>> RetrievedClass = registry['my_strategy']
255
+ >>> instance = RetrievedClass(param=42)
256
+ """
257
+ if cls_name not in self._registry:
258
+ available = ", ".join(self._registry.keys()) if self._registry else "none"
259
+ raise KeyError(
260
+ f"Protocol '{cls_name}' is not registered. "
261
+ f"Available protocols: {available}"
262
+ )
263
+ return self._registry[cls_name]
264
+
265
+ def is_registered(self, cls_name: str) -> bool:
266
+ """
267
+ Check if a protocol name is registered.
268
+
269
+ Parameters
270
+ ----------
271
+ cls_name : str
272
+ The name of the protocol to check.
273
+
274
+ Returns
275
+ -------
276
+ bool
277
+ True if the protocol is registered, False otherwise.
278
+ """
279
+ return cls_name in self._registry
280
+
281
+ @property
282
+ def registered_names(self) -> list[str]:
283
+ """
284
+ A list of all registered protocol names, sorted alphabetically.
285
+
286
+ Returns
287
+ -------
288
+ list[str]
289
+ A list of all registered protocol names, sorted alphabetically.
290
+ """
291
+ return sorted(self._registry.keys())
292
+
293
+ def get_class(self, cls_name: str, module_path: str | None = None) -> type:
294
+ """
295
+ Get a class by name from the registry or from a module path.
296
+
297
+ Parameters
298
+ ----------
299
+ cls_name: str
300
+ The name of the class to get.
301
+ module_path: str | None
302
+ The path to the module to get the class from.
303
+
304
+ Returns
305
+ -------
306
+ type
307
+ The class.
308
+
309
+ Raises
310
+ ------
311
+ NameError: If the class is not found in the registry or module.
312
+ ModuleNotFoundError: If the module is not found with the specified module path.
313
+ """
314
+ if cls_name in self.registered_names:
315
+ return self._registry[cls_name]
316
+ else:
317
+ if module_path:
318
+ module = importlib.import_module(module_path)
319
+ cls = getattr(module, cls_name, None)
320
+ if not cls:
321
+ raise NameError(
322
+ f"Class {cls_name} not found in module {module_path}"
323
+ )
324
+ return cls
325
+ else:
326
+ raise NameError(
327
+ f"Class {cls_name} not found in registry, and no module path was provided."
328
+ )
329
+
330
+
331
+ # Module-level registry instance for global access
332
+ registry = ActiveLearningRegistry()
physics_mcp/source/physicsnemo/active_learning/config.py ADDED
@@ -0,0 +1,808 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ Configuration dataclasses for the active learning driver.
19
+
20
+ This module provides structured configuration classes that separate different
21
+ concerns in the active learning workflow: optimization, training, strategies,
22
+ and driver orchestration.
23
+ """
24
+
25
+ from __future__ import annotations
26
+
27
+ import math
28
+ import uuid
29
+ from collections import defaultdict
30
+ from dataclasses import dataclass, field
31
+ from json import dumps
32
+ from pathlib import Path
33
+ from typing import Any
34
+ from warnings import warn
35
+
36
+ import torch
37
+ from torch import distributed as dist
38
+ from torch.optim import AdamW, Optimizer
39
+ from torch.optim.lr_scheduler import _LRScheduler
40
+
41
+ from physicsnemo.active_learning import protocols as p
42
+ from physicsnemo.active_learning._registry import registry
43
+ from physicsnemo.active_learning.loop import DefaultTrainingLoop
44
+ from physicsnemo.distributed import DistributedManager
45
+
46
+
47
+ @dataclass
48
+ class OptimizerConfig:
49
+ """
50
+ Configuration for optimizer and learning rate scheduler.
51
+
52
+ This encapsulates all training optimization parameters, keeping
53
+ them separate from the active learning orchestration logic.
54
+
55
+ Attributes
56
+ ----------
57
+ optimizer_cls: type[Optimizer]
58
+ The optimizer class to use. Defaults to AdamW.
59
+ optimizer_kwargs: dict[str, Any]
60
+ Keyword arguments to pass to the optimizer constructor.
61
+ Defaults to {"lr": 1e-4}.
62
+ scheduler_cls: type[_LRScheduler] | None
63
+ The learning rate scheduler class to use. If None, no
64
+ scheduler will be configured.
65
+ scheduler_kwargs: dict[str, Any]
66
+ Keyword arguments to pass to the scheduler constructor.
67
+ """
68
+
69
+ optimizer_cls: type[Optimizer] = AdamW
70
+ optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"lr": 1e-4})
71
+ scheduler_cls: type[_LRScheduler] | None = None
72
+ scheduler_kwargs: dict[str, Any] = field(default_factory=dict)
73
+
74
+ def __post_init__(self) -> None:
75
+ """Validate optimizer configuration."""
76
+ # Validate learning rate if present
77
+ if "lr" in self.optimizer_kwargs:
78
+ lr = self.optimizer_kwargs["lr"]
79
+ if not isinstance(lr, (int, float)) or lr <= 0:
80
+ raise ValueError(f"Learning rate must be positive, got {lr}")
81
+
82
+ # Validate that scheduler_kwargs is only set if scheduler_cls is provided
83
+ if self.scheduler_kwargs and self.scheduler_cls is None:
84
+ raise ValueError(
85
+ "scheduler_kwargs provided but scheduler_cls is None. "
86
+ "Provide a scheduler_cls or remove scheduler_kwargs."
87
+ )
88
+
89
+ def to_dict(self) -> dict[str, Any]:
90
+ """
91
+ Returns a JSON-serializable dictionary representation of the OptimizerConfig.
92
+
93
+ For round-tripping, the registry is used to de-serialize the optimizer and scheduler
94
+ classes.
95
+
96
+ Returns
97
+ -------
98
+ dict[str, Any]
99
+ A dictionary that can be JSON serialized.
100
+ """
101
+ opt = {
102
+ "__name__": self.optimizer_cls.__name__,
103
+ "__module__": self.optimizer_cls.__module__,
104
+ }
105
+ if self.scheduler_cls:
106
+ scheduler = {
107
+ "__name__": self.scheduler_cls.__name__,
108
+ "__module__": self.scheduler_cls.__module__,
109
+ }
110
+ else:
111
+ scheduler = None
112
+ return {
113
+ "optimizer_cls": opt,
114
+ "optimizer_kwargs": self.optimizer_kwargs,
115
+ "scheduler_cls": scheduler,
116
+ "scheduler_kwargs": self.scheduler_kwargs,
117
+ }
118
+
119
+ @classmethod
120
+ def from_dict(cls, data: dict[str, Any]) -> OptimizerConfig:
121
+ """
122
+ Creates an OptimizerConfig instance from a dictionary.
123
+
124
+ This method assumes that the optimizer and scheduler classes are
125
+ included in the ``physicsnemo.active_learning.registry``, or
126
+ a module path is specified to import the class from.
127
+
128
+ Parameters
129
+ ----------
130
+ data: dict[str, Any]
131
+ A dictionary that was previously serialized using the ``to_dict`` method.
132
+
133
+ Returns
134
+ -------
135
+ OptimizerConfig
136
+ A new ``OptimizerConfig`` instance.
137
+ """
138
+ optimizer_cls = registry.get_class(
139
+ data["optimizer_cls"]["__name__"], data["optimizer_cls"]["__module__"]
140
+ )
141
+ if (s := data.get("scheduler_cls")) is not None:
142
+ scheduler_cls = registry.get_class(s["__name__"], s["__module__"])
143
+ else:
144
+ scheduler_cls = None
145
+ return cls(
146
+ optimizer_cls=optimizer_cls,
147
+ optimizer_kwargs=data["optimizer_kwargs"],
148
+ scheduler_cls=scheduler_cls,
149
+ scheduler_kwargs=data["scheduler_kwargs"],
150
+ )
151
+
152
+
153
+ @dataclass
154
+ class TrainingConfig:
155
+ """
156
+ Configuration for the training phase of active learning.
157
+
158
+ This groups all training-related components together, making it
159
+ clear when training is or isn't being used in the AL workflow.
160
+
161
+ Attributes
162
+ ----------
163
+ train_datapool: p.DataPool
164
+ The pool of labeled data to use for training.
165
+ max_training_epochs: int
166
+ The maximum number of epochs to train for. If ``max_fine_tuning_epochs``
167
+ isn't specified, this value is used for all active learning steps.
168
+ val_datapool: p.DataPool | None
169
+ Optional pool of data to use for validation during training.
170
+ optimizer_config: OptimizerConfig
171
+ Configuration for the optimizer and scheduler. Defaults to
172
+ AdamW with lr=1e-4, no scheduler.
173
+ max_fine_tuning_epochs: int | None
174
+ The maximum number of epochs used during fine-tuning steps, i.e. after
175
+ the first active learning step. If ``None``, then the fine-tuning will
176
+ be performed for the duration of the active learning loop.
177
+ train_loop_fn: p.TrainingLoop
178
+ The training loop function that orchestrates the training process.
179
+ This defaults to a concrete implementation, ``DefaultTrainingLoop``,
180
+ which provides a very typical loop that includes the use of static
181
+ capture, etc.
182
+ """
183
+
184
+ train_datapool: p.DataPool
185
+ max_training_epochs: int
186
+ val_datapool: p.DataPool | None = None
187
+ optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
188
+ max_fine_tuning_epochs: int | None = None
189
+ train_loop_fn: p.TrainingLoop = field(default_factory=DefaultTrainingLoop)
190
+
191
+ def __post_init__(self) -> None:
192
+ """Validate training configuration."""
193
+ # Validate datapools have consistent interface
194
+ if not hasattr(self.train_datapool, "__len__"):
195
+ raise ValueError("train_datapool must implement __len__")
196
+ if self.val_datapool is not None and not hasattr(self.val_datapool, "__len__"):
197
+ raise ValueError("val_datapool must implement __len__")
198
+
199
+ # Validate training loop is callable
200
+ if not callable(self.train_loop_fn):
201
+ raise ValueError("train_loop_fn must be callable")
202
+
203
+ # set the same value for fine tuning epochs if not provided
204
+ if self.max_fine_tuning_epochs is None:
205
+ self.max_fine_tuning_epochs = self.max_training_epochs
206
+
207
+ def to_dict(self) -> dict[str, Any]:
208
+ """
209
+ Returns a JSON-serializable dictionary representation of the TrainingConfig.
210
+
211
+ For round-tripping, the registry is used to de-serialize the training loop
212
+ and optimizer configuration. Note that datapools (train_datapool and val_datapool)
213
+ are NOT serialized as they typically contain large datasets, file handles, or other
214
+ non-serializable state.
215
+
216
+ Returns
217
+ -------
218
+ dict[str, Any]
219
+ A dictionary that can be JSON serialized. Excludes datapools.
220
+
221
+ Warnings
222
+ --------
223
+ This method will issue a warning about the exclusion of datapools.
224
+ """
225
+ # Warn about datapool exclusion
226
+ warn(
227
+ "The `train_datapool` and `val_datapool` attributes are not supported for "
228
+ "serialization and will be excluded from the ``TrainingConfig`` dictionary. "
229
+ "You must re-provide these datapools when deserializing."
230
+ )
231
+
232
+ # Serialize optimizer config
233
+ optimizer_dict = self.optimizer_config.to_dict()
234
+
235
+ # Serialize training loop function
236
+ if not hasattr(self.train_loop_fn, "_args"):
237
+ raise ValueError(
238
+ f"Training loop {self.train_loop_fn} does not have an `_args` attribute "
239
+ "which is required for serialization. Make sure your training loop "
240
+ "either subclasses `ActiveLearningProtocol` or implements the `__new__` "
241
+ "method to capture object arguments."
242
+ )
243
+
244
+ train_loop_dict = self.train_loop_fn._args
245
+
246
+ return {
247
+ "max_training_epochs": self.max_training_epochs,
248
+ "max_fine_tuning_epochs": self.max_fine_tuning_epochs,
249
+ "optimizer_config": optimizer_dict,
250
+ "train_loop_fn": train_loop_dict,
251
+ }
252
+
253
+ @classmethod
254
+ def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> TrainingConfig:
255
+ """
256
+ Creates a TrainingConfig instance from a dictionary.
257
+
258
+ This method assumes that the training loop class is included in the
259
+ ``physicsnemo.active_learning.registry``, or a module path is specified
260
+ to import the class from. Note that datapools must be provided via
261
+ kwargs as they are not serialized.
262
+
263
+ Parameters
264
+ ----------
265
+ data: dict[str, Any]
266
+ A dictionary that was previously serialized using the ``to_dict`` method.
267
+ **kwargs: Any
268
+ Additional keyword arguments to pass to the constructor. This is where
269
+ you must provide ``train_datapool`` and optionally ``val_datapool``.
270
+
271
+ Returns
272
+ -------
273
+ TrainingConfig
274
+ A new ``TrainingConfig`` instance.
275
+
276
+ Raises
277
+ ------
278
+ ValueError
279
+ If required datapools are not provided in kwargs, if the data contains
280
+ unexpected keys, or if object construction fails.
281
+ """
282
+ # Ensure required datapools are provided
283
+ if "train_datapool" not in kwargs:
284
+ raise ValueError(
285
+ "``train_datapool`` must be provided in kwargs when deserializing "
286
+ "TrainingConfig, as datapools are not serialized."
287
+ )
288
+
289
+ # Reconstruct optimizer config
290
+ optimizer_config = OptimizerConfig.from_dict(data["optimizer_config"])
291
+
292
+ # Reconstruct training loop function
293
+ train_loop_data = data["train_loop_fn"]
294
+ train_loop_fn = registry.construct(
295
+ train_loop_data["__name__"],
296
+ module_path=train_loop_data["__module__"],
297
+ **train_loop_data["__args__"],
298
+ )
299
+
300
+ # Build the config
301
+ try:
302
+ config = cls(
303
+ max_training_epochs=data["max_training_epochs"],
304
+ max_fine_tuning_epochs=data.get("max_fine_tuning_epochs"),
305
+ optimizer_config=optimizer_config,
306
+ train_loop_fn=train_loop_fn,
307
+ **kwargs,
308
+ )
309
+ except Exception as e:
310
+ raise ValueError(
311
+ "Failed to construct ``TrainingConfig`` from dictionary."
312
+ ) from e
313
+
314
+ return config
315
+
316
+
317
@dataclass
class StrategiesConfig:
    """
    Configuration for active learning strategies and data acquisition.

    This encapsulates the query-label-metrology cycle that is at the
    heart of active learning: strategies for selecting data, labeling it,
    and measuring model uncertainty/performance.

    Attributes
    ----------
    query_strategies: list[p.QueryStrategy]
        The query strategies to use for selecting data to label.
    queue_cls: type[p.AbstractQueue]
        The queue implementation to use for passing data between
        query and labeling phases.
    label_strategy: p.LabelStrategy | None
        The strategy to use for labeling queried data. If None,
        labeling will be skipped.
    metrology_strategies: list[p.MetrologyStrategy] | None
        Strategies for measuring model performance and uncertainty.
        If None, metrology will be skipped.
    unlabeled_datapool: p.DataPool | None
        Pool of unlabeled data that query strategies can sample from.
        Not all strategies require this (some may generate synthetic data).
    """

    query_strategies: list[p.QueryStrategy]
    queue_cls: type[p.AbstractQueue]
    label_strategy: p.LabelStrategy | None = None
    metrology_strategies: list[p.MetrologyStrategy] | None = None
    unlabeled_datapool: p.DataPool | None = None

    def __post_init__(self) -> None:
        """
        Validate the strategies configuration.

        Raises
        ------
        ValueError
            If no query strategies are provided, if any provided strategy
            is not callable, if ``metrology_strategies`` is an empty list,
            or if ``queue_cls`` is not callable.
        """
        # Active learning is meaningless without at least one way to
        # select new data.
        if not self.query_strategies:
            raise ValueError(
                "At least one query strategy must be provided. "
                "Active learning requires a mechanism to select data."
            )

        # All query strategies must be callable
        for strategy in self.query_strategies:
            if not callable(strategy):
                raise ValueError(f"Query strategy {strategy} must be callable")

        # Label strategy must be callable if provided
        if self.label_strategy is not None and not callable(self.label_strategy):
            raise ValueError("label_strategy must be callable")

        # An explicit empty list is ambiguous: the user should either
        # provide strategies or pass None to skip metrology entirely.
        if self.metrology_strategies is not None:
            if not self.metrology_strategies:
                raise ValueError(
                    "metrology_strategies is an empty list. "
                    "Either provide strategies or set to None to skip metrology."
                )
            for strategy in self.metrology_strategies:
                if not callable(strategy):
                    raise ValueError(f"Metrology strategy {strategy} must be callable")

        # NOTE: the previous ``hasattr(self.queue_cls, "__call__")`` check was
        # vacuously true for any class object; ``callable`` expresses the
        # intent directly and still rejects plainly non-callable values.
        if not callable(self.queue_cls):
            raise ValueError("queue_cls must be a callable class")

    @staticmethod
    def _serialized_args(strategy: Any, kind: str) -> dict[str, Any]:
        """
        Return the captured constructor arguments (``_args``) for a strategy.

        Parameters
        ----------
        strategy: Any
            The strategy object being serialized.
        kind: str
            Human-readable strategy kind ("Query", "Label", "Metrology"),
            used only to build the error message.

        Returns
        -------
        dict[str, Any]
            The strategy's ``_args`` mapping.

        Raises
        ------
        ValueError
            If the strategy does not expose an ``_args`` attribute.
        """
        if not hasattr(strategy, "_args"):
            raise ValueError(
                f"{kind} strategy {strategy} does not have an `_args` attribute"
                " which is required for serialization. Make sure your strategy"
                " either subclasses `ActiveLearningProtocol` or implements"
                " the `__new__` method to capture object arguments."
            )
        return strategy._args

    def to_dict(self) -> dict[str, Any]:
        """
        Method that converts the present ``StrategiesConfig`` instance into a dictionary
        that can be JSON serialized.

        This method, for the most part, assumes that strategies are subclasses of
        ``ActiveLearningProtocol`` and/or they have an ``_args`` attribute that
        captures the arguments to the constructor.

        One issue is the inability to reliably serialize the ``unlabeled_datapool``,
        which for the most part, likely does not need serialization as a dataset.
        Regardless, this method will trigger a warning if ``unlabeled_datapool`` is
        not None.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized.

        Raises
        ------
        ValueError
            If any strategy lacks the ``_args`` attribute required for
            serialization.
        """
        output = defaultdict(list)
        for strategy in self.query_strategies:
            output["query_strategies"].append(self._serialized_args(strategy, "Query"))
        if self.label_strategy is not None:
            output["label_strategy"] = self._serialized_args(
                self.label_strategy, "Label"
            )
        # queue_cls is recorded by name/module so it can be re-resolved later
        output["queue_cls"] = {
            "__name__": self.queue_cls.__name__,
            "__module__": self.queue_cls.__module__,
        }
        if self.metrology_strategies is not None:
            for strategy in self.metrology_strategies:
                output["metrology_strategies"].append(
                    self._serialized_args(strategy, "Metrology")
                )
        if self.unlabeled_datapool is not None:
            warn(
                "The `unlabeled_datapool` attribute is not supported for serialization"
                " and will be excluded from the ``StrategiesConfig`` dictionary."
            )
        return output

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> StrategiesConfig:
        """
        Create a ``StrategiesConfig`` instance from a dictionary.

        This method heavily relies on classes being added to the
        ``physicsnemo.active_learning.registry``, which is used to instantiate
        all strategies and custom types used in active learning. As a fall
        back, the `registry.construct` method will try and import the class
        from the module path if it is not found in the registry.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor.

        Returns
        -------
        StrategiesConfig
            A new ``StrategiesConfig`` instance.

        Raises
        ------
        ValueError:
            If the data contains unexpected or missing keys, or if the
            object construction fails.
        NameError:
            If a class is not found in the registry and no module path is provided.
        ModuleNotFoundError:
            If a module is not found with the specified module path.
        """
        # ensure that the data contains no unexpected keys
        data_keys = set(data.keys())
        expected_keys = set(cls.__dataclass_fields__.keys())
        extra_keys = data_keys - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        # fail with an informative ValueError (per the documented contract)
        # instead of a bare KeyError when mandatory entries are absent
        missing_keys = {"query_strategies", "queue_cls"} - data_keys
        if missing_keys:
            raise ValueError(
                f"Missing required keys in data: {missing_keys}. "
                "Expected the output of ``StrategiesConfig.to_dict``."
            )
        # instantiate objects from the serialized data; general strategy is to
        # use `registry.construct` that will try and resolve the class within
        # the registry first, and if not found, then it will try and import the
        # class from the module path.
        output_dict = defaultdict(list)
        for entry in data["query_strategies"]:
            output_dict["query_strategies"].append(
                registry.construct(
                    entry["__name__"],
                    module_path=entry["__module__"],
                    **entry["__args__"],
                )
            )
        if "metrology_strategies" in data:
            for entry in data["metrology_strategies"]:
                output_dict["metrology_strategies"].append(
                    registry.construct(
                        entry["__name__"],
                        module_path=entry["__module__"],
                        **entry["__args__"],
                    )
                )
        if "label_strategy" in data:
            output_dict["label_strategy"] = registry.construct(
                data["label_strategy"]["__name__"],
                module_path=data["label_strategy"]["__module__"],
                **data["label_strategy"]["__args__"],
            )
        output_dict["queue_cls"] = registry.get_class(
            data["queue_cls"]["__name__"], data["queue_cls"]["__module__"]
        )
        # potentially override with keyword arguments
        output_dict.update(kwargs)
        try:
            config = cls(**output_dict)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``StrategiesConfig`` from dictionary."
            ) from e
        return config
522
+
523
+
524
@dataclass
class DriverConfig:
    """
    Configuration for driver orchestration and infrastructure.

    This contains parameters that control the overall loop execution,
    logging, checkpointing, and distributed training setup - orthogonal
    to the specific AL or training logic.

    Attributes
    ----------
    batch_size: int
        The batch size to use for data loaders.
    max_active_learning_steps: int | None, default None
        Maximum number of AL iterations to perform. None means infinite
        (normalized to ``float("inf")`` in ``__post_init__``).
    run_id: str, default auto-generated UUID
        Unique identifier for this run. Auto-generated if not provided.
    fine_tuning_lr: float | None, default None
        Learning rate to switch to after the first AL step for fine-tuning.
    reset_optim_states: bool, default True
        Whether to reset optimizer states between AL steps.
    skip_training: bool, default False
        If True, skip the training phase entirely.
    skip_metrology: bool, default False
        If True, skip the metrology phase entirely.
    skip_labeling: bool, default False
        If True, skip the labeling phase entirely.
    checkpoint_interval: int, default 1
        Save model checkpoint every N AL steps. 0 disables checkpointing.
    checkpoint_on_training: bool, default False
        If True, save checkpoint at the start of the training phase.
    checkpoint_on_metrology: bool, default False
        If True, save checkpoint at the start of the metrology phase.
    checkpoint_on_query: bool, default False
        If True, save checkpoint at the start of the query phase.
    checkpoint_on_labeling: bool, default True
        If True, save checkpoint at the start of the labeling phase.
    model_checkpoint_frequency: int, default 0
        Save model weights every N epochs during training. 0 means only save
        between active learning phases. Useful for mid-training restarts.
    root_log_dir: str | Path, default Path.cwd() / "active_learning_logs"
        Directory to save logs and checkpoints to. Defaults to
        an 'active_learning_logs' directory in the current working directory,
        resolved at instantiation time.
    dist_manager: DistributedManager | None, default None
        Manager for distributed training configuration.
    collate_fn: callable | None, default None
        Custom collate function for batching data.
    num_dataloader_workers: int, default 0
        Number of worker processes for data loading.
    device: str | torch.device | None, default None
        Device to use for model and data. This is intended for single process
        workflows; for distributed workflows, the device should be set in
        ``DistributedManager`` instead. If not specified, then the device
        will default to ``torch.get_default_device()``.
    dtype: torch.dtype | None, default None
        The dtype to use for model and data, and AMP contexts. If not provided,
        then the dtype will default to ``torch.get_default_dtype()``.
    """

    batch_size: int
    max_active_learning_steps: int | None = None
    run_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    fine_tuning_lr: float | None = None  # TODO: move to TrainingConfig
    reset_optim_states: bool = True
    skip_training: bool = False
    skip_metrology: bool = False
    skip_labeling: bool = False
    checkpoint_interval: int = 1
    checkpoint_on_training: bool = False
    checkpoint_on_metrology: bool = False
    checkpoint_on_query: bool = False
    checkpoint_on_labeling: bool = True
    model_checkpoint_frequency: int = 0
    # default_factory (not ``default=Path.cwd() / ...``) so the working
    # directory is resolved when the config is instantiated, not frozen at
    # module import time.
    root_log_dir: str | Path = field(
        default_factory=lambda: Path.cwd() / "active_learning_logs"
    )
    dist_manager: DistributedManager | None = None
    collate_fn: callable | None = None
    num_dataloader_workers: int = 0
    device: str | torch.device | None = None
    dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        """
        Validate and normalize the driver configuration.

        Raises
        ------
        ValueError
            If any numeric field is out of range or ``collate_fn`` is not
            callable.
        """
        # None means "run forever"; normalize to +inf so numeric
        # comparisons against the step counter work uniformly.
        if self.max_active_learning_steps is None:
            self.max_active_learning_steps = float("inf")

        # never None at this point due to the normalization above
        if self.max_active_learning_steps <= 0:
            raise ValueError(
                "`max_active_learning_steps` must be a positive integer or None."
            )

        if not math.isfinite(self.batch_size) or self.batch_size <= 0:
            raise ValueError("`batch_size` must be a positive integer.")

        if not math.isfinite(self.checkpoint_interval) or self.checkpoint_interval < 0:
            raise ValueError(
                "`checkpoint_interval` must be a non-negative integer. "
                "Use 0 to disable checkpointing."
            )

        if self.fine_tuning_lr is not None and self.fine_tuning_lr <= 0:
            raise ValueError("`fine_tuning_lr` must be positive if provided.")

        if self.num_dataloader_workers < 0:
            raise ValueError("`num_dataloader_workers` must be non-negative.")

        if self.model_checkpoint_frequency < 0:
            raise ValueError("`model_checkpoint_frequency` must be non-negative.")

        # normalize string paths to Path objects
        if isinstance(self.root_log_dir, str):
            self.root_log_dir = Path(self.root_log_dir)

        # Validate collate_fn if provided
        if self.collate_fn is not None and not callable(self.collate_fn):
            raise ValueError("`collate_fn` must be callable if provided.")

        # Device fallback only applies to single-process workflows; when a
        # DistributedManager is supplied, the device is owned by it.
        if self.device is None and not self.dist_manager:
            self.device = torch.get_default_device()
        # The dtype default is independent of device ownership: the documented
        # contract is that an unset dtype falls back to the global default.
        # (Previously this was nested under the device branch, leaving dtype
        # as None whenever a device was explicitly supplied.)
        if self.dtype is None:
            self.dtype = torch.get_default_dtype()

    def to_json(self) -> str:
        """
        Returns a JSON string representation of the ``DriverConfig``.

        Note that certain fields are not serialized and must be provided when
        deserializing: ``dist_manager``, ``collate_fn``.

        Returns
        -------
        str
            A JSON string representation of the config.
        """
        # base dict representation skips non-serializable Python objects
        dict_repr = {
            key: self.__dict__[key]
            for key in self.__dict__
            if key
            not in ["dist_manager", "collate_fn", "root_log_dir", "device", "dtype"]
        }
        # Note: checkpoint flags are included in dict_repr automatically
        dict_repr["default_dtype"] = str(torch.get_default_dtype())
        dict_repr["log_dir"] = str(self.root_log_dir)
        # dtype is serialized via its string repr, e.g. "torch.float32"
        dict_repr["dtype"] = str(self.dtype) if self.dtype is not None else None
        if self.dist_manager is not None:
            dict_repr["world_size"] = self.dist_manager.world_size
            dict_repr["device"] = self.dist_manager.device.type
            dict_repr["dist_manager_init_method"] = (
                self.dist_manager._initialization_method
            )
        else:
            dict_repr["world_size"] = (
                dist.get_world_size() if dist.is_initialized() else 1
            )
            if self.device is not None:
                # str() handles both torch.device objects and plain strings
                dict_repr["device"] = str(self.device)
            else:
                dict_repr["device"] = torch.get_default_device().type
            dict_repr["dist_manager_init_method"] = None
        dict_repr["collate_fn"] = (
            self.collate_fn.__name__ if self.collate_fn is not None else None
        )
        return dumps(dict_repr, indent=2)

    @classmethod
    def from_json(cls, json_str: str, **kwargs: Any) -> DriverConfig:
        """
        Creates a DriverConfig instance from a JSON string.

        This method reconstructs a DriverConfig from JSON. Note that certain
        fields cannot be serialized and must be provided via kwargs:
        - ``dist_manager``: DistributedManager instance (optional)
        - ``collate_fn``: Custom collate function (optional)

        Parameters
        ----------
        json_str: str
            A JSON string that was previously serialized using ``to_json()``.
        **kwargs: Any
            Additional keyword arguments to override or provide non-serializable
            fields like ``dist_manager`` and ``collate_fn``.

        Returns
        -------
        DriverConfig
            A new ``DriverConfig`` instance.

        Raises
        ------
        ValueError
            If the JSON cannot be parsed or required fields are missing.

        Notes
        -----
        The device and dtype fields are reconstructed from their string
        representations. The ``log_dir`` field in JSON is mapped to
        ``root_log_dir`` in the config.
        """
        import json

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON string: {e}") from e

        # Keys emitted by ``to_json`` that are metadata only, not
        # constructor parameters
        metadata_fields = [
            "default_dtype",
            "world_size",
            "dist_manager_init_method",
            "log_dir",  # handled separately as root_log_dir
        ]
        # Fields that need special reconstruction or cannot be serialized
        non_serializable_fields = [
            "dist_manager",
            "collate_fn",
            "root_log_dir",
            "device",
            "dtype",
        ]

        # Extract serializable fields that map directly
        config_fields = {
            key: value
            for key, value in data.items()
            if key not in metadata_fields + non_serializable_fields
        }

        # Handle root_log_dir (stored as "log_dir" in JSON)
        if "log_dir" in data:
            config_fields["root_log_dir"] = Path(data["log_dir"])

        # Handle device reconstruction from strings like "cuda:0", "cpu"
        if "device" in data and data["device"] is not None:
            config_fields["device"] = torch.device(data["device"])

        # Handle dtype reconstruction from string
        if "dtype" in data and data["dtype"] is not None:
            dtype_str = data["dtype"]
            # Resolve e.g. "torch.float32" -> torch.float32 by attribute
            # lookup on the torch module; this supports every dtype torch
            # exposes rather than a fixed allow-list.
            candidate = getattr(torch, dtype_str.rsplit(".", 1)[-1], None)
            if isinstance(candidate, torch.dtype):
                config_fields["dtype"] = candidate
            else:
                warn(
                    f"Unknown dtype string '{dtype_str}' in JSON. "
                    "Using default dtype instead."
                )

        # Merge with provided kwargs (allows overriding and adding
        # non-serializable fields)
        config_fields.update(kwargs)

        # Create the config
        try:
            config = cls(**config_fields)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``DriverConfig`` from JSON string."
            ) from e

        return config
physics_mcp/source/physicsnemo/active_learning/driver.py ADDED
@@ -0,0 +1,1449 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This module contains the definition for an active learning driver
19
+ class, which is responsible for orchestration and automation of
20
+ the end-to-end active learning process.
21
+ """
22
+
23
+ from __future__ import annotations
24
+
25
+ import inspect
26
+ import pickle
27
+ from contextlib import contextmanager
28
+ from copy import deepcopy
29
+ from dataclasses import dataclass
30
+ from pathlib import Path
31
+ from typing import Any, Generator
32
+
33
+ import torch
34
+ from torch import distributed as dist
35
+ from torch.nn.parallel import DistributedDataParallel
36
+ from torch.utils.data import DataLoader, DistributedSampler
37
+
38
+ from physicsnemo import Module
39
+ from physicsnemo import __version__ as physicsnemo_version
40
+ from physicsnemo.active_learning import protocols as p
41
+ from physicsnemo.active_learning.config import (
42
+ DriverConfig,
43
+ StrategiesConfig,
44
+ TrainingConfig,
45
+ )
46
+ from physicsnemo.active_learning.logger import (
47
+ ActiveLearningLoggerAdapter,
48
+ setup_active_learning_logger,
49
+ )
50
+ from physicsnemo.distributed import DistributedManager
51
+
52
+
53
@dataclass
class ActiveLearningCheckpoint:
    """
    Metadata associated with an ongoing (or completed) active
    learning experiment.

    The information contained in this metadata should be sufficient
    to restart the active learning experiment at the nearest point:
    for example, training should be able to continue from an epoch,
    while for querying/sampling, etc. we continue from a pre-existing
    queue.
    """

    # Orchestration/infrastructure configuration the run was started with.
    driver_config: DriverConfig
    # Active learning strategies (query/label/metrology) for the run.
    strategies_config: StrategiesConfig
    # Index of the active learning iteration at checkpoint time.
    active_learning_step_idx: int
    # Phase (training/metrology/query/labeling) active at checkpoint time.
    active_learning_phase: p.ActiveLearningPhase
    # Version of physicsnemo that produced the checkpoint; defaults to the
    # version of the module currently imported.
    physicsnemo_version: str = physicsnemo_version
    # Training components; None when training is skipped.
    training_config: TrainingConfig | None = None
    # Optimizer state snapshot, if any (presumably a ``state_dict`` --
    # confirm against Driver's save/load logic).
    optimizer_state: dict[str, Any] | None = None
    # LR scheduler state snapshot, if any (presumably a ``state_dict`` --
    # confirm against Driver's save/load logic).
    lr_scheduler_state: dict[str, Any] | None = None
    # Flags recording whether query/label queue contents accompany this
    # checkpoint (serialized separately -- confirm with Driver's checkpoint
    # writer).
    has_query_queue: bool = False
    has_label_queue: bool = False
76
+
77
+
78
+ class Driver(p.DriverProtocol):
79
+ """
80
+ Provides a simple implementation of the ``DriverProtocol`` used to
81
+ orchestrate an active learning process within PhysicsNeMo.
82
+
83
+ At a high level, the active learning process is broken down into four
84
+ phases: training, metrology, query, and labeling.
85
+
86
+ To understand the orchestration, start by inspecting the
87
+ ``active_learning_step`` method, which defines a single iteration of
88
+ the active learning loop, which is dispatched by the ``run`` method.
89
+ From there, it should be relatively straightforward to trace the
90
+ remaining components.
91
+
92
+ Attributes
93
+ ----------
94
+ config: DriverConfig
95
+ Infrastructure and orchestration configuration.
96
+ learner: Module | p.LearnerProtocol
97
+ The learner module for the active learning process.
98
+ strategies_config: StrategiesConfig
99
+ Active learning strategies (query, label, metrology).
100
+ training_config: TrainingConfig | None
101
+ Training components. None if training is skipped.
102
+ inference_fn: p.InferenceProtocol | None
103
+ Custom inference function.
104
+ active_learning_step_idx: int
105
+ Current iteration index of the active learning loop.
106
+ query_queue: p.AbstractQueue
107
+ Queue populated with data by query strategies.
108
+ label_queue: p.AbstractQueue
109
+ Queue populated with labeled data by the label strategy.
110
+ optimizer: torch.optim.Optimizer | None
111
+ Configured optimizer (set after configure_optimizer is called).
112
+ lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None
113
+ Configured learning rate scheduler.
114
+ logger: logging.Logger
115
+ Persistent logger for the active learning process.
116
+ """
117
+
118
+ # Phase execution order for active learning step (immutable)
119
+ _PHASE_ORDER = [
120
+ p.ActiveLearningPhase.TRAINING,
121
+ p.ActiveLearningPhase.METROLOGY,
122
+ p.ActiveLearningPhase.QUERY,
123
+ p.ActiveLearningPhase.LABELING,
124
+ ]
125
+
126
    def __init__(
        self,
        config: DriverConfig,
        learner: Module | p.LearnerProtocol,
        strategies_config: StrategiesConfig,
        training_config: TrainingConfig | None = None,
        inference_fn: p.InferenceProtocol | None = None,
    ) -> None:
        """
        Initializes the active learning driver.

        At the bare minimum, the driver requires a config, learner, and
        strategies config to be used in a purely querying loop. Additional
        arguments can be provided to enable training and other workflows.

        Parameters
        ----------
        config: DriverConfig
            Orchestration and infrastructure configuration, for example
            the batch size, the log directory, the distributed manager, etc.
        learner: Module | p.LearnerProtocol
            The model to use for active learning.
        strategies_config: StrategiesConfig
            Container for active learning strategies (query, label, metrology).
        training_config: TrainingConfig | None
            Training components. Required if ``skip_training`` is False in
            the ``DriverConfig``.
        inference_fn: p.InferenceProtocol | None
            Custom inference function. If None, uses ``learner.__call__``.
            This is not actually called by the driver, but is stored as an
            attribute for attached strategies to use as needed.

        Raises
        ------
        ValueError
            If the composed configs are inconsistent (see
            ``_validate_config_consistency``).
        """
        # Configs have already validated themselves in __post_init__;
        # only cross-config constraints are checked here.
        self.config = config
        self.learner = learner
        self.strategies_config = strategies_config
        self.training_config = training_config
        self.inference_fn = inference_fn
        # Fresh runs always start at step zero; restarts overwrite this
        # via the checkpoint-loading path.
        self.active_learning_step_idx = 0
        self.current_phase: p.ActiveLearningPhase | None = (
            None  # Track current phase for logging context
        )
        # Populated the first time a checkpoint is written.
        self._last_checkpoint_path: Path | None = None

        # Validate cross-config constraints
        self._validate_config_consistency()

        # Logger must exist before strategies attach, so they can log
        # during attachment.
        self._setup_logger()
        self.attach_strategies()

        # Initialize queues from strategies_config: one for data produced
        # by query strategies, one for data produced by the label strategy.
        self.query_queue = strategies_config.queue_cls()
        self.label_queue = strategies_config.queue_cls()
179
+
180
+ def _validate_config_consistency(self) -> None:
181
+ """
182
+ Validate consistency across configs.
183
+
184
+ Each config validates itself, but this method checks relationships
185
+ between configs that can only be validated when composed together.
186
+ """
187
+ # If training is not skipped, training_config must be provided
188
+ if not self.config.skip_training and self.training_config is None:
189
+ raise ValueError(
190
+ "`training_config` must be provided when `skip_training` is False."
191
+ )
192
+
193
+ # If labeling is not skipped, must have label strategy and train datapool
194
+ if not self.config.skip_labeling:
195
+ if self.strategies_config.label_strategy is None:
196
+ raise ValueError(
197
+ "`label_strategy` must be provided in strategies_config "
198
+ "when `skip_labeling` is False."
199
+ )
200
+ if (
201
+ self.training_config is None
202
+ or self.training_config.train_datapool is None
203
+ ):
204
+ raise ValueError(
205
+ "`train_datapool` must be provided in training_config "
206
+ "when `skip_labeling` is False (labeled data is appended to it)."
207
+ )
208
+
209
+ # If fine-tuning lr is set, must have training enabled
210
+ if self.config.fine_tuning_lr is not None and self.config.skip_training:
211
+ raise ValueError(
212
+ "`fine_tuning_lr` has no effect when `skip_training` is True."
213
+ )
214
+
215
    @property
    def query_strategies(self) -> list[p.QueryStrategy]:
        """Convenience accessor for ``strategies_config.query_strategies``."""
        return self.strategies_config.query_strategies
219
+
220
    @property
    def label_strategy(self) -> p.LabelStrategy | None:
        """Convenience accessor for ``strategies_config.label_strategy``;
        None when labeling is not configured."""
        return self.strategies_config.label_strategy
224
+
225
    @property
    def metrology_strategies(self) -> list[p.MetrologyStrategy] | None:
        """Convenience accessor for ``strategies_config.metrology_strategies``;
        None when metrology is not configured."""
        return self.strategies_config.metrology_strategies
229
+
230
    @property
    def unlabeled_datapool(self) -> p.DataPool | None:
        """Convenience accessor for ``strategies_config.unlabeled_datapool``;
        None when no unlabeled pool was provided."""
        return self.strategies_config.unlabeled_datapool
234
+
235
+ @property
236
+ def train_datapool(self) -> p.DataPool | None:
237
+ """Returns the training datapool from training_config."""
238
+ return self.training_config.train_datapool if self.training_config else None
239
+
240
+ @property
241
+ def val_datapool(self) -> p.DataPool | None:
242
+ """Returns the validation datapool from training_config."""
243
+ return self.training_config.val_datapool if self.training_config else None
244
+
245
+ @property
246
+ def train_loop_fn(self) -> p.TrainingLoop | None:
247
+ """Returns the training loop function from training_config."""
248
+ return self.training_config.train_loop_fn if self.training_config else None
249
+
250
+ @property
251
+ def device(self) -> torch.device:
252
+ """Return a consistent device interface to use across the driver."""
253
+ if self.dist_manager is not None and self.dist_manager.is_initialized():
254
+ return self.dist_manager.device
255
+ else:
256
+ return torch.get_default_device()
257
+
258
    @property
    def run_id(self) -> str:
        """Returns the run id from the ``DriverConfig``.

        The id uniquely identifies this run; by default ``DriverConfig``
        generates it as a UUID4 string.

        Returns
        -------
        str
            The run id.
        """
        return self.config.run_id
268
+
269
    @property
    def log_dir(self) -> Path:
        """Returns the log directory for this specific run.

        Note that this is the ``DriverConfig.root_log_dir`` combined
        with the shortened run ID for the current run.

        Effectively, this means that each run will have its own
        directory for logs, checkpoints, etc.

        Returns
        -------
        Path
            The per-run log directory (not created by this property).
        """
        return self.config.root_log_dir / self.short_run_id
285
+
286
    @property
    def short_run_id(self) -> str:
        """Returns the first 8 characters of the run id.

        The 8 character limit assumes that the run ID is a UUID4.
        This is particularly useful for user-facing interfaces,
        where you do not necessarily want to reference the full UUID.

        Returns
        -------
        str
            The first 8 characters of the run id (or the full id if it is
            shorter than 8 characters).
        """
        return self.run_id[:8]
300
+
301
    @property
    def last_checkpoint(self) -> Path | None:
        """
        Returns path to the most recently saved checkpoint.

        The underlying attribute is initialized to None in ``__init__``
        and updated by the checkpoint-saving logic.

        Returns
        -------
        Path | None
            Path to the last checkpoint directory, or None if no checkpoint
            has been saved yet.
        """
        return self._last_checkpoint_path
313
+
314
+ @property
315
+ def active_learning_step_idx(self) -> int:
316
+ """
317
+ Returns the current active learning step index.
318
+
319
+ This represents the number of times the active learning step
320
+ has been called, i.e. the number of iterations of the loop.
321
+
322
+ Returns
323
+ -------
324
+ int
325
+ The current active learning step index.
326
+ """
327
+ return self._active_learning_step_idx
328
+
329
+ @active_learning_step_idx.setter
330
+ def active_learning_step_idx(self, value: int) -> None:
331
+ """
332
+ Sets the current active learning step index.
333
+
334
+ Parameters
335
+ ----------
336
+ value: int
337
+ The new active learning step index.
338
+
339
+ Raises
340
+ ------
341
+ ValueError
342
+ If the new active learning step index is negative.
343
+ """
344
+ if value < 0:
345
+ raise ValueError("Active learning step index must be non-negative.")
346
+ self._active_learning_step_idx = value
347
+
348
+ @property
349
+ def dist_manager(self) -> DistributedManager | None:
350
+ """Returns the distributed manager, if it was specified as part
351
+ of the `DriverConfig` configuration.
352
+
353
+ Returns
354
+ -------
355
+ DistributedManager | None
356
+ The distributed manager.
357
+ """
358
+ return self.config.dist_manager
359
+
360
+ def configure_optimizer(self) -> None:
361
+ """Setup optimizer and LR schedulers from training_config."""
362
+ if self.training_config is None:
363
+ self.optimizer = None
364
+ self.lr_scheduler = None
365
+ return
366
+
367
+ opt_cfg = self.training_config.optimizer_config
368
+
369
+ if opt_cfg.optimizer_cls is not None:
370
+ try:
371
+ _ = inspect.signature(opt_cfg.optimizer_cls).bind(
372
+ self.learner.parameters(), **opt_cfg.optimizer_kwargs
373
+ )
374
+ except TypeError as e:
375
+ raise ValueError(
376
+ f"Invalid optimizer kwargs for {opt_cfg.optimizer_cls}; {e}"
377
+ )
378
+ self.optimizer = opt_cfg.optimizer_cls(
379
+ self.learner.parameters(), **opt_cfg.optimizer_kwargs
380
+ )
381
+ else:
382
+ self.optimizer = None
383
+ return
384
+
385
+ if opt_cfg.scheduler_cls is not None and self.optimizer is not None:
386
+ try:
387
+ _ = inspect.signature(opt_cfg.scheduler_cls).bind(
388
+ self.optimizer, **opt_cfg.scheduler_kwargs
389
+ )
390
+ except TypeError as e:
391
+ raise ValueError(
392
+ f"Invalid LR scheduler kwargs for {opt_cfg.scheduler_cls}; {e}"
393
+ )
394
+ self.lr_scheduler = opt_cfg.scheduler_cls(
395
+ self.optimizer, **opt_cfg.scheduler_kwargs
396
+ )
397
+ else:
398
+ self.lr_scheduler = None
399
+ # in the case where we want to reset optimizer states between active learning steps
400
+ if self.config.reset_optim_states and self.is_optimizer_configured:
401
+ self._original_optim_state = deepcopy(self.optimizer.state_dict())
402
+
403
+ @property
404
+ def is_optimizer_configured(self) -> bool:
405
+ """Returns whether the optimizer is configured."""
406
+ return getattr(self, "optimizer", None) is not None
407
+
408
+ @property
409
+ def is_lr_scheduler_configured(self) -> bool:
410
+ """Returns whether the LR scheduler is configured."""
411
+ return getattr(self, "lr_scheduler", None) is not None
412
+
413
    def attach_strategies(self) -> None:
        """Calls ``strategy.attach`` for all available strategies.

        Currently delegates unchanged to the parent implementation; the
        override exists so driver-specific attachment behavior has a single
        place to live.
        """
        super().attach_strategies()
416
+
417
+ def _setup_logger(self) -> None:
418
+ """
419
+ Sets up a persistent logger for the driver.
420
+
421
+ This logger is specialized in that it provides additional context
422
+ information depending on the part of the active learning cycle.
423
+ """
424
+ base_logger = setup_active_learning_logger(
425
+ "core.active_learning",
426
+ run_id=self.run_id,
427
+ log_dir=self.log_dir,
428
+ )
429
+ # Wrap with adapter to automatically include iteration context
430
+ self.logger = ActiveLearningLoggerAdapter(base_logger, driver_ref=self)
431
+
432
+ def _should_checkpoint_at_step(self) -> bool:
433
+ """
434
+ Determine if a checkpoint should be saved at the current AL step.
435
+
436
+ Uses the `checkpoint_interval` from config to decide. If interval is 0,
437
+ checkpointing is disabled. Otherwise, checkpoint at step 0 and every
438
+ N steps thereafter.
439
+
440
+ Returns
441
+ -------
442
+ bool
443
+ True if checkpoint should be saved, False otherwise.
444
+ """
445
+ if self.config.checkpoint_interval == 0:
446
+ return False
447
+ # Always checkpoint at step 0, then every checkpoint_interval steps
448
+ return self.active_learning_step_idx % self.config.checkpoint_interval == 0
449
+
450
+ def _serialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> bool:
451
+ """
452
+ Serialize queue to a file.
453
+
454
+ If queue implements `to_list()`, serialize the list. Otherwise, use
455
+ torch.save to serialize the entire queue object.
456
+
457
+ Parameters
458
+ ----------
459
+ queue: p.AbstractQueue
460
+ The queue to serialize.
461
+ file_path: Path
462
+ Path where the queue should be saved.
463
+
464
+ Returns
465
+ -------
466
+ bool
467
+ True if serialization succeeded, False otherwise.
468
+ """
469
+ try:
470
+ if hasattr(queue, "to_list") and callable(getattr(queue, "to_list")):
471
+ # Use custom serialization method
472
+ queue_data = {"type": "list", "data": queue.to_list()}
473
+ else:
474
+ # Fallback to torch.save for the entire queue
475
+ queue_data = {"type": "torch", "data": queue}
476
+
477
+ torch.save(queue_data, file_path)
478
+ return True
479
+ except (TypeError, AttributeError, pickle.PicklingError, RuntimeError) as e:
480
+ # Some queues cannot be pickled, e.g. stdlib queue.Queue with thread locks
481
+ # Clean up any partially written file
482
+ if file_path.exists():
483
+ file_path.unlink()
484
+
485
+ self.logger.warning(
486
+ f"Failed to serialize queue to {file_path}: {e}. Queue state will not be saved. "
487
+ f"Consider implementing to_list()/from_list() methods for custom serialization."
488
+ )
489
+ return False
490
+
491
+ def _deserialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> None:
492
+ """
493
+ Restore queue from a file.
494
+
495
+ Parameters
496
+ ----------
497
+ queue: p.AbstractQueue
498
+ The queue to restore data into.
499
+ file_path: Path
500
+ Path to the saved queue file.
501
+ """
502
+ if not file_path.exists():
503
+ return
504
+
505
+ try:
506
+ queue_data = torch.load(file_path, map_location="cpu", weights_only=False)
507
+
508
+ if queue_data["type"] == "list":
509
+ if hasattr(queue, "from_list") and callable(
510
+ getattr(queue, "from_list")
511
+ ):
512
+ queue.from_list(queue_data["data"])
513
+ else:
514
+ # Manually populate queue from list
515
+ for item in queue_data["data"]:
516
+ queue.put(item)
517
+ elif queue_data["type"] == "torch":
518
+ # Restore from torch-saved queue - copy items to current queue
519
+ restored_queue = queue_data["data"]
520
+ # Copy items from restored queue to current queue
521
+ while not restored_queue.empty():
522
+ queue.put(restored_queue.get())
523
+ except Exception as e:
524
+ self.logger.warning(
525
+ f"Failed to deserialize queue from {file_path}: {e}. "
526
+ f"Queue will be empty."
527
+ )
528
+
529
    def save_checkpoint(
        self, path: str | Path | None = None, training_epoch: int | None = None
    ) -> Path | None:
        """
        Save a checkpoint of the active learning experiment.

        Saves AL orchestration state (configs, queues, step index, phase) and model weights.
        Training-specific state (optimizer, scheduler) is handled by DefaultTrainingLoop
        and saved to training_state.pt during training.

        Parameters
        ----------
        path: str | Path | None
            Path to save checkpoint. If None, creates path based on current
            AL step index and phase: log_dir/checkpoints/step_{idx}/{phase}/
        training_epoch: int | None
            Optional epoch number for mid-training checkpoints.

        Returns
        -------
        Path | None
            Checkpoint directory path, or None if checkpoint not saved (non-rank-0 in distributed).
        """
        # Determine checkpoint directory
        if path is None:
            # before any phase has run, current_phase is falsy -> "init"
            phase_name = self.current_phase if self.current_phase else "init"
            checkpoint_dir = (
                self.log_dir
                / "checkpoints"
                / f"step_{self.active_learning_step_idx}"
                / phase_name
            )
            if training_epoch is not None:
                checkpoint_dir = checkpoint_dir / f"epoch_{training_epoch}"
        else:
            checkpoint_dir = Path(path)

        # Create checkpoint directory
        # NOTE(review): the directory is created on *every* rank before the
        # rank-0 check below, so non-zero ranks also mkdir — presumably so
        # all ranks observe the path; confirm this is intended.
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Only rank 0 saves checkpoint in distributed setting
        if self.dist_manager is not None and self.dist_manager.is_initialized():
            if self.dist_manager.rank != 0:
                return None

        # Serialize configurations
        driver_config_json = self.config.to_json()
        strategies_config_dict = self.strategies_config.to_dict()
        training_config_dict = (
            self.training_config.to_dict() if self.training_config else None
        )

        # Serialize queue states to separate files
        query_queue_file = checkpoint_dir / "query_queue.pt"
        label_queue_file = checkpoint_dir / "label_queue.pt"
        # the has_* flags record whether each queue could actually be pickled
        has_query_queue = self._serialize_queue(self.query_queue, query_queue_file)
        has_label_queue = self._serialize_queue(self.label_queue, label_queue_file)

        # Create checkpoint dataclass (only AL orchestration state)
        checkpoint = ActiveLearningCheckpoint(
            driver_config=driver_config_json,
            strategies_config=strategies_config_dict,
            active_learning_step_idx=self.active_learning_step_idx,
            # no phase recorded yet defaults to TRAINING (first phase)
            active_learning_phase=self.current_phase or p.ActiveLearningPhase.TRAINING,
            physicsnemo_version=physicsnemo_version,
            training_config=training_config_dict,
            optimizer_state=None,  # Training loop handles this
            lr_scheduler_state=None,  # Training loop handles this
            has_query_queue=has_query_queue,
            has_label_queue=has_label_queue,
        )

        # Add training epoch if in mid-training checkpoint
        checkpoint_dict = {
            "checkpoint": checkpoint,
        }
        if training_epoch is not None:
            checkpoint_dict["training_epoch"] = training_epoch

        # Save checkpoint metadata
        checkpoint_path = checkpoint_dir / "checkpoint.pt"
        torch.save(checkpoint_dict, checkpoint_path)

        # Save model weights (separate from training state)
        if isinstance(self.learner, Module):
            # physicsnemo Module: save via its own .mdlus format
            model_name = (
                self.learner.meta.name
                if self.learner.meta
                else self.learner.__class__.__name__
            )
            model_path = checkpoint_dir / f"{model_name}.mdlus"
            self.learner.save(str(model_path))
        elif hasattr(self.learner, "module") and isinstance(
            self.learner.module, Module
        ):
            # Unwrap DDP
            model_name = (
                self.learner.module.meta.name
                if self.learner.module.meta
                else self.learner.module.__class__.__name__
            )
            model_path = checkpoint_dir / f"{model_name}.mdlus"
            self.learner.module.save(str(model_path))
        else:
            # generic learner: fall back to a plain state_dict .pt file
            model_name = self.learner.__class__.__name__
            model_path = checkpoint_dir / f"{model_name}.pt"
            torch.save(self.learner.state_dict(), model_path)

        # Update last checkpoint path
        self._last_checkpoint_path = checkpoint_dir

        # Log successful checkpoint save
        self.logger.info(
            f"Saved checkpoint at step {self.active_learning_step_idx}, "
            f"phase {self.current_phase}: {checkpoint_dir}"
        )

        return checkpoint_dir
647
+
648
    @classmethod
    def load_checkpoint(
        cls,
        checkpoint_path: str | Path,
        learner: Module | p.LearnerProtocol | None = None,
        train_datapool: p.DataPool | None = None,
        val_datapool: p.DataPool | None = None,
        unlabeled_datapool: p.DataPool | None = None,
        **kwargs: Any,
    ) -> Driver:
        """
        Load a Driver instance from a checkpoint.

        Given a checkpoint directory, this method will attempt to reconstruct
        the driver and its associated components from the checkpoint. The
        checkpoint path must contain a ``checkpoint.pt`` file, which contains
        the metadata associated with the experiment.

        Additional parameters that might not be serialized with the checkpointing
        mechanism can/need to be provided to this method; for example when
        using non-`physicsnemo.Module` learners, and any data pools associated
        with the workflow.

        .. important::

            Currently, the strategy states are not reloaded from the checkpoint.
            This will be addressed in a future patch, but for now it is recommended
            to back up your strategy states (e.g. metrology records) manually
            before restarting experiments.

        Parameters
        ----------
        checkpoint_path: str | Path
            Path to checkpoint directory containing checkpoint.pt and model weights.
        learner: Module | p.LearnerProtocol | None
            Learner model to load weights into. If None, will attempt to
            reconstruct from checkpoint (only works for physicsnemo.Module).
        train_datapool: p.DataPool | None
            Training datapool. Required if training_config exists in checkpoint.
        val_datapool: p.DataPool | None
            Validation datapool. Optional.
        unlabeled_datapool: p.DataPool | None
            Unlabeled datapool for query strategies. Optional.
        **kwargs: Any
            Additional keyword arguments to override config values. Recognized
            keys include ``driver_config_overrides``,
            ``strategies_config_overrides``, ``training_config_overrides``,
            and ``inference_fn``.

        Returns
        -------
        Driver
            Reconstructed Driver instance ready to resume execution.

        Raises
        ------
        FileNotFoundError
            If ``checkpoint.pt`` does not exist under ``checkpoint_path``.
        ValueError
            If no learner was provided and none could be reconstructed from
            a ``.mdlus`` file in the checkpoint directory.
        """
        checkpoint_path = Path(checkpoint_path)

        # Load checkpoint file
        checkpoint_file = checkpoint_path / "checkpoint.pt"
        if not checkpoint_file.exists():
            raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_file}")

        # NOTE(review): weights_only=False unpickles arbitrary objects —
        # only load checkpoints from trusted sources.
        checkpoint_dict = torch.load(
            checkpoint_file, map_location="cpu", weights_only=False
        )
        checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"]
        training_epoch = checkpoint_dict.get("training_epoch", None)

        # Reconstruct configs
        driver_config = DriverConfig.from_json(
            checkpoint.driver_config, **kwargs.get("driver_config_overrides", {})
        )

        # TODO add strategy state loading from checkpoint
        strategies_config = StrategiesConfig.from_dict(
            checkpoint.strategies_config,
            unlabeled_datapool=unlabeled_datapool,
            **kwargs.get("strategies_config_overrides", {}),
        )

        training_config = None
        if checkpoint.training_config is not None:
            training_config = TrainingConfig.from_dict(
                checkpoint.training_config,
                train_datapool=train_datapool,
                val_datapool=val_datapool,
                **kwargs.get("training_config_overrides", {}),
            )

        # Load or reconstruct learner
        if learner is None:
            # Attempt to reconstruct from checkpoint (only for Module)
            # Try to find any .mdlus file in the checkpoint directory
            mdlus_files = list(checkpoint_path.glob("*.mdlus"))
            if mdlus_files:
                # Use the first .mdlus file found
                model_path = mdlus_files[0]
                learner = Module.from_checkpoint(str(model_path))
            else:
                raise ValueError(
                    "No learner provided and unable to reconstruct from checkpoint. "
                    "Please provide a learner instance."
                )
        else:
            # Load model weights into provided learner
            # Determine expected model filename based on learner type
            if isinstance(learner, Module):
                model_name = (
                    learner.meta.name if learner.meta else learner.__class__.__name__
                )
                model_path = checkpoint_path / f"{model_name}.mdlus"
                if model_path.exists():
                    learner.load(str(model_path))
                else:
                    # Fallback: try to find any .mdlus file
                    mdlus_files = list(checkpoint_path.glob("*.mdlus"))
                    if mdlus_files:
                        learner.load(str(mdlus_files[0]))
            elif hasattr(learner, "module") and isinstance(learner.module, Module):
                # Unwrap DDP
                model_name = (
                    learner.module.meta.name
                    if learner.module.meta
                    else learner.module.__class__.__name__
                )
                model_path = checkpoint_path / f"{model_name}.mdlus"
                if model_path.exists():
                    learner.module.load(str(model_path))
                else:
                    # Fallback: try to find any .mdlus file
                    mdlus_files = list(checkpoint_path.glob("*.mdlus"))
                    if mdlus_files:
                        learner.module.load(str(mdlus_files[0]))
            else:
                # Non-Module learner: look for .pt file with class name
                model_name = learner.__class__.__name__
                model_path = checkpoint_path / f"{model_name}.pt"
                if model_path.exists():
                    state_dict = torch.load(model_path, map_location="cpu")
                    learner.load_state_dict(state_dict)
                else:
                    # Fallback: try to find any .pt file
                    pt_files = list(checkpoint_path.glob("*.pt"))
                    # Filter out checkpoint.pt and queue files
                    model_pt_files = [
                        f
                        for f in pt_files
                        if f.name
                        not in [
                            "checkpoint.pt",
                            "query_queue.pt",
                            "label_queue.pt",
                            "training_state.pt",
                        ]
                    ]
                    if model_pt_files:
                        state_dict = torch.load(model_pt_files[0], map_location="cpu")
                        learner.load_state_dict(state_dict)

        # Instantiate Driver
        driver = cls(
            config=driver_config,
            learner=learner,
            strategies_config=strategies_config,
            training_config=training_config,
            inference_fn=kwargs.get("inference_fn", None),
        )

        # Restore active learning state
        driver.active_learning_step_idx = checkpoint.active_learning_step_idx
        driver.current_phase = checkpoint.active_learning_phase
        driver._last_checkpoint_path = checkpoint_path

        # Load training state (optimizer, scheduler) if training_config exists
        # This delegates to the training loop's checkpoint loading logic
        if driver.training_config is not None:
            driver.configure_optimizer()

            # Use training loop to load training state (including model weights again if needed)
            from physicsnemo.active_learning.loop import DefaultTrainingLoop

            DefaultTrainingLoop.load_training_checkpoint(
                checkpoint_dir=checkpoint_path,
                model=driver.learner,
                optimizer=driver.optimizer,
                lr_scheduler=driver.lr_scheduler
                if hasattr(driver, "lr_scheduler")
                else None,
            )

        # Restore queue states from separate files
        if checkpoint.has_query_queue:
            query_queue_file = checkpoint_path / "query_queue.pt"
            driver._deserialize_queue(driver.query_queue, query_queue_file)

        if checkpoint.has_label_queue:
            label_queue_file = checkpoint_path / "label_queue.pt"
            driver._deserialize_queue(driver.label_queue, label_queue_file)

        driver.logger.info(
            f"Loaded checkpoint from {checkpoint_path} at step "
            f"{checkpoint.active_learning_step_idx}, phase {checkpoint.active_learning_phase}"
        )
        if training_epoch is not None:
            driver.logger.info(f"Resuming from training epoch {training_epoch}")

        return driver
851
+
852
+ def barrier(self) -> None:
853
+ """
854
+ Wrapper to call barrier on the correct device.
855
+
856
+ Becomes a no-op if distributed is not initialized, otherwise
857
+ will attempt to read the local device ID from either the distributed manager
858
+ or the default device.
859
+ """
860
+ if dist.is_initialized():
861
+ if (
862
+ self.dist_manager is not None
863
+ and self.dist_manager.device.type == "cuda"
864
+ ):
865
+ dist.barrier(device_ids=[self.dist_manager.local_rank])
866
+ elif torch.get_default_device().type == "cuda":
867
+ # this might occur if distributed manager is not used
868
+ dist.barrier(device_ids=[torch.cuda.current_device()])
869
+ else:
870
+ dist.barrier()
871
+
872
    def _configure_model(self) -> None:
        """
        Method that encapsulates all the logic for preparing the model
        ahead of time.

        If the distributed manager has been configured and initialized
        with a world size greater than 1, then we wrap the model in DDP.
        Otherwise, we simply move the model to the correct device.

        After the model has been moved to device, we configure the optimizer
        and learning rate scheduler if training is enabled.
        """
        if self.dist_manager is not None and self.dist_manager.is_initialized():
            if self.dist_manager.world_size > 1 and not isinstance(
                self.learner, DistributedDataParallel
            ):
                # wrap the model in DDP
                self.learner = torch.nn.parallel.DistributedDataParallel(
                    self.learner,
                    device_ids=[self.dist_manager.local_rank],
                    output_device=self.dist_manager.device,
                    broadcast_buffers=self.dist_manager.broadcast_buffers,
                    find_unused_parameters=self.dist_manager.find_unused_parameters,
                )
            # NOTE(review): with world_size == 1 the model is neither wrapped
            # nor explicitly moved via config.device here — presumably the
            # dist_manager already placed it; confirm.
        else:
            if self.config.device is not None:
                self.learner = self.learner.to(self.config.device, self.config.dtype)
        # assume all device management is done via the dist_manager, so at this
        # point the model is on the correct device and we can set up the optimizer
        # if we intend to train
        if not self.config.skip_training and not self.is_optimizer_configured:
            self.configure_optimizer()
        # restore the pristine optimizer state captured at configure time
        if self.is_optimizer_configured and self.config.reset_optim_states:
            self.optimizer.load_state_dict(self._original_optim_state)
906
+
907
+ def _get_phase_index(self, phase: p.ActiveLearningPhase | None) -> int:
908
+ """
909
+ Get index of phase in execution order.
910
+
911
+ Parameters
912
+ ----------
913
+ phase: p.ActiveLearningPhase | None
914
+ Phase to find index for. If None, returns 0 (start from beginning).
915
+
916
+ Returns
917
+ -------
918
+ int
919
+ Index in _PHASE_ORDER (0-3).
920
+ """
921
+ if phase is None:
922
+ return 0
923
+ try:
924
+ return self._PHASE_ORDER.index(phase)
925
+ except ValueError:
926
+ self.logger.warning(
927
+ f"Unknown phase {phase}, defaulting to start from beginning"
928
+ )
929
+ return 0
930
+
931
+ def _build_phase_queue(
932
+ self,
933
+ train_step_fn: p.TrainingProtocol | None,
934
+ validate_step_fn: p.ValidationProtocol | None,
935
+ args: tuple,
936
+ kwargs: dict,
937
+ ) -> list[Any]:
938
+ """
939
+ Build list of phase functions to execute for this AL step.
940
+
941
+ If current_phase is set (e.g., from checkpoint), only phases at or after
942
+ current_phase are included. Otherwise, all non-skipped phases are included.
943
+
944
+ Parameters
945
+ ----------
946
+ train_step_fn: p.TrainingProtocol | None
947
+ Training function to pass to training phase.
948
+ validate_step_fn: p.ValidationProtocol | None
949
+ Validation function to pass to training phase.
950
+ args: tuple
951
+ Additional arguments to pass to phase methods.
952
+ kwargs: dict
953
+ Additional keyword arguments to pass to phase methods.
954
+
955
+ Returns
956
+ -------
957
+ list[Callable]
958
+ Queue of phase functions to execute in order.
959
+ """
960
+ # Define all possible phases with their execution conditions
961
+ all_phases = [
962
+ (
963
+ p.ActiveLearningPhase.TRAINING,
964
+ lambda: self._training_phase(
965
+ train_step_fn, validate_step_fn, *args, **kwargs
966
+ ),
967
+ not self.config.skip_training,
968
+ ),
969
+ (
970
+ p.ActiveLearningPhase.METROLOGY,
971
+ lambda: self._metrology_phase(*args, **kwargs),
972
+ not self.config.skip_metrology,
973
+ ),
974
+ (
975
+ p.ActiveLearningPhase.QUERY,
976
+ lambda: self._query_phase(*args, **kwargs),
977
+ True, # Query phase always runs
978
+ ),
979
+ (
980
+ p.ActiveLearningPhase.LABELING,
981
+ lambda: self._labeling_phase(*args, **kwargs),
982
+ not self.config.skip_labeling,
983
+ ),
984
+ ]
985
+
986
+ # Find starting index based on current_phase (resume point)
987
+ start_idx = self._get_phase_index(self.current_phase)
988
+
989
+ if start_idx > 0:
990
+ self.logger.info(
991
+ f"Resuming AL step {self.active_learning_step_idx} from "
992
+ f"{self.current_phase}"
993
+ )
994
+
995
+ # Build queue: only phases from start_idx onwards that should run
996
+ phase_queue = []
997
+ for idx, (phase, phase_fn, should_run) in enumerate(all_phases):
998
+ # Skip phases before current_phase
999
+ if idx < start_idx:
1000
+ self.logger.debug(
1001
+ f"Skipping {phase} (already completed in this AL step)"
1002
+ )
1003
+ continue
1004
+
1005
+ # Add phase to queue if not skipped by config
1006
+ if should_run:
1007
+ phase_queue.append(phase_fn)
1008
+ else:
1009
+ self.logger.debug(f"Skipping {phase} (disabled in config)")
1010
+
1011
+ return phase_queue
1012
+
1013
+ def _construct_dataloader(
1014
+ self, pool: p.DataPool, shuffle: bool = False, drop_last: bool = False
1015
+ ) -> DataLoader:
1016
+ """
1017
+ Helper method to construct a data loader for a given data pool.
1018
+
1019
+ In the case that a distributed manager was provided, then a distributed
1020
+ sampler will be used, which will be bound to the current rank.
1021
+ Otherwise, a regular sampler will be used. Similarly, if your data
1022
+ structure requires a specialized function to construct batches,
1023
+ then this function can be provided via the `collate_fn` argument.
1024
+
1025
+ Parameters
1026
+ ----------
1027
+ pool: p.DataPool
1028
+ The data pool to construct a data loader for.
1029
+ shuffle: bool = False
1030
+ Whether to shuffle the data.
1031
+ drop_last: bool = False
1032
+ Whether to drop the last batch if it is not complete.
1033
+
1034
+ Returns
1035
+ -------
1036
+ DataLoader
1037
+ The constructed data loader.
1038
+ """
1039
+ # if a distributed manager was omitted, then we assume single process
1040
+ if self.dist_manager is not None and self.dist_manager.is_initialized():
1041
+ sampler = DistributedSampler(
1042
+ pool,
1043
+ num_replicas=self.dist_manager.world_size,
1044
+ rank=self.dist_manager.rank,
1045
+ shuffle=shuffle,
1046
+ drop_last=drop_last,
1047
+ )
1048
+ # set to None, because sampler will handle instead
1049
+ shuffle = None
1050
+ else:
1051
+ sampler = None
1052
+ # fully spec out the data loader
1053
+ pin_memory = False
1054
+ if self.dist_manager is not None and self.dist_manager.is_initialized():
1055
+ if self.dist_manager.device.type == "cuda":
1056
+ pin_memory = True
1057
+ loader = DataLoader(
1058
+ pool,
1059
+ shuffle=shuffle,
1060
+ sampler=sampler,
1061
+ collate_fn=self.config.collate_fn,
1062
+ batch_size=self.config.batch_size,
1063
+ num_workers=self.config.num_dataloader_workers,
1064
+ persistent_workers=self.config.num_dataloader_workers > 0,
1065
+ pin_memory=pin_memory,
1066
+ )
1067
+ return loader
1068
+
1069
+ def active_learning_step(
1070
+ self,
1071
+ train_step_fn: p.TrainingProtocol | None = None,
1072
+ validate_step_fn: p.ValidationProtocol | None = None,
1073
+ *args: Any,
1074
+ **kwargs: Any,
1075
+ ) -> None:
1076
+ """
1077
+ Performs a single active learning iteration.
1078
+
1079
+ This method will perform the following sequence of steps:
1080
+ 1. Train the model stored in ``Driver.learner`` by creating data loaders
1081
+ with ``Driver.train_datapool`` and ``Driver.val_datapool``.
1082
+ 2. Run the metrology strategies stored in ``Driver.metrology_strategies``.
1083
+ 3. Run the query strategies stored in ``Driver.query_strategies``, if available.
1084
+ 4. Run the labeling strategy stored in ``Driver.label_strategy``, if available.
1085
+
1086
+ When entering each stage, we check to ensure all components necessary for the
1087
+ minimum function for that stage are available before proceeding.
1088
+
1089
+ If current_phase is set (e.g., from checkpoint resumption), only phases at
1090
+ or after current_phase will be executed. After completing all phases,
1091
+ current_phase is reset to None for the next AL step.
1092
+
1093
+ Parameters
1094
+ ----------
1095
+ train_step_fn: p.TrainingProtocol | None = None
1096
+ The training function to use for training. If not provided, then the
1097
+ ``Driver.train_loop_fn`` will be used.
1098
+ validate_step_fn: p.ValidationProtocol | None = None
1099
+ The validation function to use for validation. If not provided, then
1100
+ validation will not be performed.
1101
+ args: Any
1102
+ Additional arguments to pass to the method. These will be passed to the
1103
+ training loop, metrology strategies, query strategies, and labeling strategies.
1104
+ kwargs: Any
1105
+ Additional keyword arguments to pass to the method. These will be passed to the
1106
+ training loop, metrology strategies, query strategies, and labeling strategies.
1107
+
1108
+ Raises
1109
+ ------
1110
+ ValueError
1111
+ If any of the required components for a stage are not available.
1112
+ """
1113
+ self._setup_active_learning_step()
1114
+
1115
+ # Build queue of phase functions based on current_phase
1116
+ phase_queue = self._build_phase_queue(
1117
+ train_step_fn, validate_step_fn, args, kwargs
1118
+ )
1119
+
1120
+ # Execute each phase in order (de-populate queue)
1121
+ for phase_fn in phase_queue:
1122
+ phase_fn()
1123
+
1124
+ # Reset current_phase after completing all phases in this AL step
1125
+ self.current_phase = None
1126
+
1127
+ self.logger.debug("Entering barrier for synchronization.")
1128
+ self.barrier()
1129
+ self.active_learning_step_idx += 1
1130
+ self.logger.info(
1131
+ f"Completed active learning step {self.active_learning_step_idx}"
1132
+ )
1133
+
1134
+ def _setup_active_learning_step(self) -> None:
1135
+ """Initialize distributed manager and configure model for the active learning step."""
1136
+ if self.dist_manager is not None and not self.dist_manager.is_initialized():
1137
+ self.logger.info(
1138
+ "Distributed manager configured but not initialized; initializing."
1139
+ )
1140
+ self.dist_manager.initialize()
1141
+ self._configure_model()
1142
+ self.logger.info(
1143
+ f"Starting active learning step {self.active_learning_step_idx}"
1144
+ )
1145
+
1146
    def _training_phase(
        self,
        train_step_fn: p.TrainingProtocol | None,
        validate_step_fn: p.ValidationProtocol | None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """Execute the training phase of the active learning step.

        Validates prerequisites, builds the train/validation dataloaders,
        optionally switches to a fine-tuning learning rate after the first
        active learning step, resumes from a mid-training checkpoint when
        one exists, and finally delegates epoch execution to
        ``self.train_loop_fn``.

        Parameters
        ----------
        train_step_fn: p.TrainingProtocol | None
            Training-step callable; may be None when the learner implements
            ``training_step`` itself (enforced by validation below).
        validate_step_fn: p.ValidationProtocol | None
            Validation-step callable; validation is skipped when neither
            this nor a learner ``validation_step`` is available.
        args: Any
            Accepted for interface symmetry with the other phases.
            NOTE(review): ``*args`` is never forwarded to ``train_loop_fn``
            (only ``**kwargs`` is) — confirm this is intended.
        kwargs: Any
            Extra keyword arguments forwarded to ``train_loop_fn``.
        """
        self._validate_training_requirements(train_step_fn, validate_step_fn)

        # don't need to barrier because it'll be done at the end of training anyway
        with self._phase_context("training", call_barrier=False):
            # Note: Training phase checkpointing is handled by the training loop itself
            # during epoch execution based on model_checkpoint_frequency

            train_loader = self._construct_dataloader(self.train_datapool, shuffle=True)
            self.logger.info(
                f"There are {len(train_loader)} batches in the training loader."
            )
            # Only build a validation loader when something can consume it.
            val_loader = None
            if self.val_datapool is not None:
                if validate_step_fn or hasattr(self.learner, "validation_step"):
                    val_loader = self._construct_dataloader(
                        self.val_datapool, shuffle=False
                    )
                else:
                    self.logger.warning(
                        "Validation data is available, but no `validate_step_fn` "
                        "or `validation_step` method in Learner is provided."
                    )
            # if a fine-tuning lr is provided, adjust it after the first iteration
            # (only the first optimizer param group is touched)
            if (
                self.config.fine_tuning_lr is not None
                and self.active_learning_step_idx > 0
            ):
                self.optimizer.param_groups[0]["lr"] = self.config.fine_tuning_lr

            # Determine max epochs to train for this AL step: full training
            # budget on the first step, fine-tuning budget afterwards.
            if self.active_learning_step_idx > 0:
                target_max_epochs = self.training_config.max_fine_tuning_epochs
            else:
                target_max_epochs = self.training_config.max_training_epochs

            # Check if resuming from mid-training checkpoint
            start_epoch = 1
            epochs_to_train = target_max_epochs

            if self._last_checkpoint_path and self._last_checkpoint_path.exists():
                training_state_path = self._last_checkpoint_path / "training_state.pt"
                if training_state_path.exists():
                    # weights_only=False: the state file contains arbitrary
                    # pickled objects — only load checkpoints you trust.
                    training_state = torch.load(
                        training_state_path, map_location="cpu", weights_only=False
                    )
                    # "training_epoch" records the last COMPLETED epoch.
                    last_completed_epoch = training_state.get("training_epoch", 0)
                    if last_completed_epoch > 0:
                        start_epoch = last_completed_epoch + 1
                        epochs_to_train = target_max_epochs - last_completed_epoch
                        self.logger.info(
                            f"Resuming training from epoch {start_epoch} "
                            f"({epochs_to_train} epochs remaining)"
                        )

            # Skip training if all epochs already completed
            if epochs_to_train <= 0:
                self.logger.info(
                    f"Training already complete ({target_max_epochs} epochs), "
                    f"skipping training phase"
                )
                return

            # Prefer the distributed manager's device when running distributed.
            device = (
                self.dist_manager.device
                if self.dist_manager is not None
                else self.config.device
            )
            dtype = self.config.dtype

            # Set checkpoint directory and frequency on training loop
            # This allows the training loop to handle training state checkpointing internally
            if hasattr(self.train_loop_fn, "checkpoint_base_dir") and hasattr(
                self.train_loop_fn, "checkpoint_frequency"
            ):
                # Checkpoint base is the current AL step's training directory
                checkpoint_base = (
                    self.log_dir
                    / "checkpoints"
                    / f"step_{self.active_learning_step_idx}"
                    / "training"
                )
                self.train_loop_fn.checkpoint_base_dir = checkpoint_base
                self.train_loop_fn.checkpoint_frequency = (
                    self.config.model_checkpoint_frequency
                )

            self.train_loop_fn(
                self.learner,
                self.optimizer,
                train_step_fn=train_step_fn,
                validate_step_fn=validate_step_fn,
                train_dataloader=train_loader,
                validation_dataloader=val_loader,
                lr_scheduler=self.lr_scheduler,
                max_epochs=epochs_to_train,  # Only remaining epochs
                device=device,
                dtype=dtype,
                **kwargs,
            )
1253
+
1254
+ def _metrology_phase(self, *args: Any, **kwargs: Any) -> None:
1255
+ """Execute the metrology phase of the active learning step."""
1256
+
1257
+ with self._phase_context("metrology"):
1258
+ for strategy in self.metrology_strategies:
1259
+ self.logger.info(
1260
+ f"Running metrology strategy: {strategy.__class__.__name__}"
1261
+ )
1262
+ strategy(*args, **kwargs)
1263
+ self.logger.info(
1264
+ f"Completed metrics for strategy: {strategy.__class__.__name__}"
1265
+ )
1266
+ strategy.serialize_records(*args, **kwargs)
1267
+
1268
+ def _query_phase(self, *args: Any, **kwargs: Any) -> None:
1269
+ """Execute the query phase of the active learning step."""
1270
+ with self._phase_context("query"):
1271
+ for strategy in self.query_strategies:
1272
+ self.logger.info(
1273
+ f"Running query strategy: {strategy.__class__.__name__}"
1274
+ )
1275
+ strategy(self.query_queue, *args, **kwargs)
1276
+
1277
+ if self.query_queue.empty():
1278
+ self.logger.warning(
1279
+ "Querying strategies produced no samples this iteration."
1280
+ )
1281
+
1282
+ def _labeling_phase(self, *args: Any, **kwargs: Any) -> None:
1283
+ """Execute the labeling phase of the active learning step."""
1284
+ self._validate_labeling_requirements()
1285
+
1286
+ if self.query_queue.empty():
1287
+ self.logger.warning("No samples to label. Skipping labeling phase.")
1288
+ return
1289
+
1290
+ with self._phase_context("labeling"):
1291
+ try:
1292
+ self.label_strategy(self.query_queue, self.label_queue, *args, **kwargs)
1293
+ except Exception as e:
1294
+ self.logger.error(f"Exception encountered during labeling: {e}")
1295
+ self.logger.info("Labeling completed. Now appending to training pool.")
1296
+
1297
+ # TODO this is done serially, could be improved with batched writes
1298
+ sample_counter = 0
1299
+ while not self.label_queue.empty():
1300
+ self.train_datapool.append(self.label_queue.get())
1301
+ sample_counter += 1
1302
+ self.logger.info(f"Appended {sample_counter} samples to training pool.")
1303
+
1304
    def _validate_training_requirements(
        self,
        train_step_fn: p.TrainingProtocol | None,
        validate_step_fn: p.ValidationProtocol | None,
    ) -> None:
        """Validate that all required components for training are available.

        Checks run in order and the first failure raises, so earlier
        requirements mask later ones.

        Parameters
        ----------
        train_step_fn: p.TrainingProtocol | None
            Candidate training-step callable; may be None only when the
            learner implements ``training_step`` itself.
        validate_step_fn: p.ValidationProtocol | None
            Candidate validation-step callable; requires ``val_datapool``
            to be configured when provided.

        Raises
        ------
        ValueError
            If the training config, training loop, training datapool, or a
            usable training step function is missing, or if a validation
            step is given without a validation datapool.
        """
        if self.training_config is None:
            raise ValueError(
                "`training_config` must be provided if `skip_training` is False."
            )
        if self.train_loop_fn is None:
            raise ValueError("`train_loop_fn` must be provided in training_config.")
        if self.train_datapool is None:
            raise ValueError("`train_datapool` must be provided in training_config.")
        if not train_step_fn and not hasattr(self.learner, "training_step"):
            raise ValueError(
                "`train_step_fn` must be provided if the model does not implement "
                "the `training_step` method."
            )
        if validate_step_fn and self.val_datapool is None:
            raise ValueError(
                "`val_datapool` must be provided in training_config if "
                "`validate_step_fn` is provided."
            )
1328
+
1329
    def _validate_labeling_requirements(self) -> None:
        """Validate that all required components for labeling are available.

        Raises
        ------
        ValueError
            If no ``label_strategy`` was configured, or if there is no
            training datapool to append freshly labeled samples to.
        """
        if self.label_strategy is None:
            raise ValueError(
                "`label_strategy` must be provided in strategies_config if "
                "`skip_labeling` is False."
            )
        # Labeled samples are appended to the training pool, so labeling
        # cannot proceed without one.
        if self.training_config is None or self.train_datapool is None:
            raise ValueError(
                "`train_datapool` must be provided in training_config for "
                "labeling, as data will be appended to it."
            )
1341
+
1342
    @contextmanager
    def _phase_context(
        self, phase_name: p.ActiveLearningPhase, call_barrier: bool = True
    ) -> Generator[None, Any, None]:
        """
        Context manager for consistent phase tracking, error handling, and synchronization.

        Sets the current phase for logging context, handles exceptions,
        and synchronizes distributed workers with a barrier. Also triggers
        checkpoint saves at the start of each phase if configured.

        Parameters
        ----------
        phase_name: p.ActiveLearningPhase
            A discrete phase of the active learning workflow. Note it is
            interpolated into the config attribute name
            ``checkpoint_on_<phase_name>``, so the phase value must
            stringify to the attribute suffix.
        call_barrier: bool
            Whether to call barrier for synchronization at the end. The
            barrier runs in ``finally``, i.e. even when the phase raised.

        Yields
        ------
        None
            Control is yielded to the phase body; exceptions are logged
            and re-raised.
        """
        self.current_phase = phase_name

        # Save checkpoint at START of phase if configured
        # Exception: training phase handles checkpointing internally
        if phase_name != p.ActiveLearningPhase.TRAINING:
            should_checkpoint = getattr(
                self.config, f"checkpoint_on_{phase_name}", False
            )
            # Check if we should checkpoint based on interval
            if should_checkpoint and self._should_checkpoint_at_step():
                self.save_checkpoint()

        try:
            yield
        except Exception as e:
            # Log with phase context before propagating to the caller.
            self.logger.error(f"Exception encountered during {phase_name}: {e}")
            raise
        finally:
            if call_barrier:
                self.logger.debug("Entering barrier for synchronization.")
                self.barrier()
1381
+
1382
+ def run(
1383
+ self,
1384
+ train_step_fn: p.TrainingProtocol | None = None,
1385
+ validate_step_fn: p.ValidationProtocol | None = None,
1386
+ *args: Any,
1387
+ **kwargs: Any,
1388
+ ) -> None:
1389
+ """
1390
+ Runs the active learning loop until the maximum number of
1391
+ active learning steps is reached.
1392
+
1393
+ Parameters
1394
+ ----------
1395
+ train_step_fn: p.TrainingProtocol | None = None
1396
+ The training function to use for training. If not provided, then the
1397
+ ``Driver.train_loop_fn`` will be used.
1398
+ validate_step_fn: p.ValidationProtocol | None = None
1399
+ The validation function to use for validation. If not provided, then
1400
+ validation will not be performed.
1401
+ args: Any
1402
+ Additional arguments to pass to the method. These will be passed to the
1403
+ training loop, metrology strategies, query strategies, and labeling strategies.
1404
+ kwargs: Any
1405
+ Additional keyword arguments to pass to the method. These will be passed to the
1406
+ training loop, metrology strategies, query strategies, and labeling strategies.
1407
+ """
1408
+ # TODO: refactor initialization logic here instead of inside the step
1409
+ while self.active_learning_step_idx < self.config.max_active_learning_steps:
1410
+ self.active_learning_step(
1411
+ train_step_fn=train_step_fn,
1412
+ validate_step_fn=validate_step_fn,
1413
+ *args,
1414
+ **kwargs,
1415
+ )
1416
+
1417
+ def __call__(
1418
+ self,
1419
+ train_step_fn: p.TrainingProtocol | None = None,
1420
+ validate_step_fn: p.ValidationProtocol | None = None,
1421
+ *args: Any,
1422
+ **kwargs: Any,
1423
+ ) -> None:
1424
+ """
1425
+ Provides syntactic sugar for running the active learning loop.
1426
+
1427
+ Calls ``Driver.run`` internally.
1428
+
1429
+ Parameters
1430
+ ----------
1431
+ train_step_fn: p.TrainingProtocol | None = None
1432
+ The training function to use for training. If not provided, then the
1433
+ ``Driver.train_loop_fn`` will be used.
1434
+ validate_step_fn: p.ValidationProtocol | None = None
1435
+ The validation function to use for validation. If not provided, then
1436
+ validation will not be performed.
1437
+ args: Any
1438
+ Additional arguments to pass to the method. These will be passed to the
1439
+ training loop, metrology strategies, query strategies, and labeling strategies.
1440
+ kwargs: Any
1441
+ Additional keyword arguments to pass to the method. These will be passed to the
1442
+ training loop, metrology strategies, query strategies, and labeling strategies.
1443
+ """
1444
+ self.run(
1445
+ train_step_fn=train_step_fn,
1446
+ validate_step_fn=validate_step_fn,
1447
+ *args,
1448
+ **kwargs,
1449
+ )
physics_mcp/source/physicsnemo/active_learning/logger.py ADDED
@@ -0,0 +1,330 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+
19
+ import json
20
+ import logging
21
+ from contextlib import contextmanager
22
+ from datetime import datetime
23
+ from pathlib import Path
24
+ from threading import local
25
+ from typing import Any
26
+
27
+ try:
28
+ from termcolor import colored
29
+ except ImportError:
30
+ colored = None
31
+
32
+
33
+ # Thread-local storage for context information
34
+ _context_storage = local()
35
+
36
+
37
class ActiveLearningLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that enriches records with active learning context.

    When constructed with a reference to the driver, each processed message
    gains ``iteration`` (from ``active_learning_step_idx``), ``run_id``,
    and ``phase`` (from ``current_phase``) entries in its ``extra`` mapping
    so downstream formatters can render them.
    """

    def __init__(self, logger: logging.Logger, driver_ref: Any = None):
        """Wrap ``logger``, optionally tracking ``driver_ref`` for context.

        Parameters
        ----------
        logger : logging.Logger
            The underlying logger to adapt
        driver_ref : Any, optional
            Object whose ``active_learning_step_idx`` / ``run_id`` /
            ``current_phase`` attributes are sampled at log time
        """
        super().__init__(logger, {})
        self.driver_ref = driver_ref

    def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Inject iteration, run ID, and phase context into ``extra``.

        Parameters
        ----------
        msg : str
            The log message
        kwargs : dict[str, Any]
            Additional keyword arguments

        Returns
        -------
        tuple[str, dict[str, Any]]
            The unchanged message and (possibly augmented) kwargs
        """
        driver = self.driver_ref
        if driver is None:
            return msg, kwargs

        extra = kwargs.get("extra", {})
        # Map driver attribute -> extra key, skipping attributes that are
        # absent or currently None.
        for attr_name, extra_key in (
            ("active_learning_step_idx", "iteration"),
            ("run_id", "run_id"),
            ("current_phase", "phase"),
        ):
            value = getattr(driver, attr_name, None)
            if value is not None:
                extra[extra_key] = value

        if extra:
            kwargs["extra"] = extra

        return msg, kwargs
98
+
99
+
100
class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging to files.

    Converts each record to a single JSON line containing the core fields,
    contextual attributes attached by filters/adapters, and every remaining
    record attribute so nothing is lost for structured analysis.
    """

    def format(self, record: logging.LogRecord) -> str:
        """Format the log record as JSON.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            JSON-formatted log message
        """
        log_entry = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add contextual information if available
        if hasattr(record, "context"):
            log_entry["context"] = record.context

        if hasattr(record, "caller_object"):
            log_entry["caller_object"] = record.caller_object

        if hasattr(record, "iteration"):
            log_entry["iteration"] = record.iteration

        if hasattr(record, "phase"):
            log_entry["phase"] = record.phase

        # Add any remaining record attributes (levelno, thread info,
        # user-supplied extras, ...) not already present.
        extra_keys = list(filter(lambda x: x not in log_entry, record.__dict__.keys()))
        for key in extra_keys:
            log_entry[key] = record.__dict__[key]

        # Fix: default=str keeps the formatter from raising — record
        # attributes such as ``args`` or ``exc_info`` (and arbitrary
        # user extras) are not necessarily JSON-serializable, and a
        # logging formatter must never fail while emitting a line.
        return json.dumps(log_entry, default=str)
149
+
150
+
151
def _get_context_stack():
    """Return (creating on first use) the current thread's context stack."""
    try:
        return _context_storage.context_stack
    except AttributeError:
        # First access on this thread: lazily initialize an empty stack.
        _context_storage.context_stack = []
        return _context_storage.context_stack
156
+
157
+
158
class ContextFormatter(logging.Formatter):
    """Standard formatter that includes active learning context information with colors.

    Context fragments (object, run, iteration, phase, extras) are appended
    to the formatted message in brackets; the message is colorized by level
    and the context in blue when ``termcolor`` is installed.
    """

    def format(self, record):
        # Collect "key:value" fragments for whichever context attrs are set.
        fragments = []
        caller = getattr(record, "caller_object", None)
        if caller:
            fragments.append(f"obj:{caller}")
        run_id = getattr(record, "run_id", None)
        if run_id:
            fragments.append(f"run:{run_id}")
        # Iteration uses an explicit None check so iteration 0 still shows.
        iteration = getattr(record, "iteration", None)
        if iteration is not None:
            fragments.append(f"iter:{iteration}")
        phase = getattr(record, "phase", None)
        if phase:
            fragments.append(f"phase:{phase}")
        extra_context = getattr(record, "context", None)
        if extra_context:
            for key, value in extra_context.items():
                fragments.append(f"{key}:{value}")

        context_str = f"[{', '.join(fragments)}]" if fragments else ""

        # Delegate base formatting to logging.Formatter.
        base_msg = super().format(record)

        # Colorize by severity when termcolor is available.
        if colored is not None:
            if record.levelno >= logging.ERROR:
                base_msg = colored(base_msg, "red")
            elif record.levelno >= logging.WARNING:
                base_msg = colored(base_msg, "yellow")
            elif record.levelno >= logging.INFO:
                base_msg = colored(base_msg, "white")
            else:  # DEBUG
                base_msg = colored(base_msg, "cyan")

        # Append the (optionally blue) context string.
        if context_str:
            if colored is not None:
                context_str = colored(context_str, "blue")
            base_msg += f" {context_str}"

        return base_msg
200
+
201
+
202
class ContextInjectingFilter(logging.Filter):
    """Filter that injects contextual information into log records.

    Reads the innermost entry of the thread-local context stack (managed by
    ``log_context``) and copies its fields onto the record so formatters can
    render them. Never suppresses records (always returns True).
    """

    def filter(self, record):
        # Add context information from thread-local storage
        context_stack = _get_context_stack()
        if context_stack:
            # Only the innermost (most recently entered) context applies.
            current_context = context_stack[-1]
            # Contexts built by ``log_context`` always carry these keys;
            # "phase" is read defensively with .get for older entries.
            if current_context["caller_object"]:
                record.caller_object = current_context["caller_object"]
            # Explicit None check so an iteration of 0 is still injected.
            if current_context["iteration"] is not None:
                record.iteration = current_context["iteration"]
            if current_context.get("phase"):
                record.phase = current_context["phase"]
            if current_context["context"]:
                record.context = current_context["context"]
        return True
219
+
220
+
221
def setup_active_learning_logger(
    name: str,
    run_id: str,
    log_dir: str | Path = Path("active_learning_logs"),
    level: int = logging.INFO,
) -> logging.Logger:
    """Set up a logger with active learning-specific formatting and handlers.

    Attaches a colorized console handler (``ContextFormatter``) and a JSON
    file handler (``JSONFormatter``) writing to ``<log_dir>/<run_id>.log``;
    both inject thread-local context via ``ContextInjectingFilter``.

    Parameters
    ----------
    name : str
        Logger name
    run_id : str
        Unique identifier for this run, used in log filename. Note the file
        is opened in ``"w"`` mode, so a repeated run_id truncates the
        previous log.
    log_dir : str | Path, optional
        Directory to store log files, by default ``"active_learning_logs"``
        (created if it does not exist)
    level : int, optional
        Logging level of the logger itself, by default logging.INFO.
        Handlers are set to DEBUG, so this level is the effective gate.

    Returns
    -------
    logging.Logger
        Configured standard Python logger (handlers reset, propagation off)

    Example
    -------
    >>> logger = setup_active_learning_logger("experiment", "run_001")
    >>> logger.info("Starting experiment")
    >>> with log_context(caller_object="Trainer", iteration=5):
    ...     logger.info("Training step")
    """
    # Get standard logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Clear any existing handlers to avoid duplicates
    logger.handlers.clear()

    # Disable propagation to prevent duplicate messages from parent loggers
    logger.propagate = False

    # Create log directory if it doesn't exist
    if isinstance(log_dir, str):
        log_dir_path = Path(log_dir)
    else:
        log_dir_path = log_dir
    log_dir_path.mkdir(parents=True, exist_ok=True)

    # Set up console handler with standard formatting
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    console_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(console_handler)

    # Set up file handler with JSON formatting
    log_file = log_dir_path / f"{run_id}.log"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(JSONFormatter())
    file_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(file_handler)

    return logger
287
+
288
+
289
@contextmanager
def log_context(
    caller_object: str | None = None,
    iteration: int | None = None,
    phase: str | None = None,
    **kwargs: Any,
):
    """Context manager for adding contextual information to log messages.

    Pushes a context frame onto a thread-local stack consumed by
    ``ContextInjectingFilter``; frames nest, and only the innermost frame
    is applied to records. The frame is always popped on exit, even when
    the body raises. Being thread-local, contexts do not leak across
    threads (NOTE: nor do they propagate to worker threads).

    Parameters
    ----------
    caller_object : str, optional
        Name or identifier of the object making the log call
    iteration : int, optional
        Current iteration counter
    phase : str, optional
        Current phase of the active learning process
    **kwargs : Any
        Additional contextual key-value pairs

    Example
    -------
    >>> from logging import getLogger
    >>> from physicsnemo.active_learning.logger import log_context
    >>> logger = getLogger("my_logger")
    >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2):
    ...     logger.info("Processing batch")
    """
    context_info = {
        "caller_object": caller_object,
        "iteration": iteration,
        "phase": phase,
        "context": kwargs,
    }

    context_stack = _get_context_stack()
    context_stack.append(context_info)

    try:
        yield
    finally:
        # Always unwind, even if the body raised.
        context_stack.pop()
physics_mcp/source/physicsnemo/active_learning/loop.py ADDED
@@ -0,0 +1,534 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from __future__ import annotations
18
+
19
+ import inspect
20
+ from pathlib import Path
21
+ from typing import Any
22
+
23
+ import torch
24
+ from torch.optim import Optimizer
25
+ from torch.optim.lr_scheduler import _LRScheduler
26
+ from torch.utils.data import DataLoader
27
+ from tqdm import tqdm
28
+
29
+ from physicsnemo import Module
30
+ from physicsnemo.active_learning import protocols as p
31
+ from physicsnemo.distributed import DistributedManager
32
+ from physicsnemo.launch.logging import LaunchLogger
33
+ from physicsnemo.utils.capture import StaticCaptureEvaluateNoGrad, StaticCaptureTraining
34
+
35
+ __all__ = ["DefaultTrainingLoop"]
36
+
37
+
38
+ def _recursive_data_device_cast(
39
+ data: Any,
40
+ device: torch.device | str | None = None,
41
+ dtype: torch.dtype | None = None,
42
+ **kwargs: Any,
43
+ ) -> Any:
44
+ """
45
+ Recursively moves/cast input data to a specified device and dtype.
46
+
47
+ For iterable objects, we recurse through the elements depending on
48
+ the type of iterable until we reach an object that either has a ``to``
49
+ method that can be called, or just returns the data unchanged.
50
+
51
+ Parameters
52
+ ----------
53
+ data: Any
54
+ The data to move to the device.
55
+ device: torch.device | str | None = None
56
+ The device to move the data to.
57
+ dtype: torch.dtype | None = None
58
+ The dtype to move the data to.
59
+ kwargs: Any
60
+ Additional keyword arguments to pass to the `to` method.
61
+ By default, `non_blocking` is set to `True` to allow
62
+ asynchronous data transfers.
63
+
64
+ Returns
65
+ -------
66
+ Any
67
+ The data moved to the device.
68
+ """
69
+ kwargs.setdefault("non_blocking", True)
70
+ if hasattr(data, "to"):
71
+ # if there is a `to` method, then we can just call it
72
+ return data.to(device=device, dtype=dtype, **kwargs)
73
+ elif isinstance(data, dict):
74
+ return {
75
+ k: _recursive_data_device_cast(v, device, dtype) for k, v in data.items()
76
+ }
77
+ elif isinstance(data, list):
78
+ return [_recursive_data_device_cast(v, device, dtype) for v in data]
79
+ elif isinstance(data, tuple):
80
+ return tuple(_recursive_data_device_cast(v, device, dtype) for v in data)
81
+ else:
82
+ return data
83
+
84
+
85
+ class DefaultTrainingLoop(p.TrainingLoop):
86
    def __new__(cls, *args: Any, **kwargs: Any) -> DefaultTrainingLoop:
        """
        Wrapper for instantiating DefaultTrainingLoop.

        This method captures arguments used to instantiate the loop
        and stores them in the `_args` attribute for serialization.
        This follows the same pattern as `ActiveLearningProtocol.__new__`.

        Parameters
        ----------
        args: Any
            Arguments to pass to the loop's constructor.
        kwargs: Any
            Keyword arguments to pass to the loop's constructor.

        Returns
        -------
        DefaultTrainingLoop
            A new instance with an `_args` attribute for serialization.
        """
        out = super().__new__(cls)

        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)

        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        bound_args.apply_defaults()

        # Get args and kwargs (excluding self and unroll kwargs)
        # NOTE(review): zipping sig.parameters with bound_args.arguments
        # assumes both iterate in the same (signature) order — after
        # apply_defaults() this holds for fully-bound calls; verify for
        # partially-bound edge cases.
        instantiate_args = {}
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            # Skip self
            if k == "self":
                continue

            # Add args and kwargs to instantiate_args
            if param.kind == param.VAR_KEYWORD:
                # Flatten **kwargs into the top-level mapping.
                instantiate_args.update(v)
            else:
                # Special handling for device: convert torch.device to string
                # (keeps the serialized form JSON/YAML-friendly)
                if k == "device" and isinstance(v, torch.device):
                    instantiate_args[k] = str(v)
                # Special handling for dtype: convert to string representation
                elif k == "dtype" and isinstance(v, torch.dtype):
                    instantiate_args[k] = str(v)
                else:
                    instantiate_args[k] = v

        # Store args needed for instantiation
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out
144
+
145
+ def __init__(
146
+ self,
147
+ train_step_fn: p.TrainingProtocol | None = None,
148
+ validate_step_fn: p.ValidationProtocol | None = None,
149
+ enable_static_capture: bool = True,
150
+ use_progress_bars: bool = True,
151
+ device: str | torch.device | None = None,
152
+ dtype: torch.dtype | None = None,
153
+ checkpoint_frequency: int = 0,
154
+ **capture_kwargs: Any,
155
+ ) -> None:
156
+ """
157
+ Initializes the default training loop.
158
+
159
+ The general usage of this loop is to
160
+
161
+ TODO: add support for early stopping
162
+
163
+ Parameters
164
+ ----------
165
+ train_step_fn: TrainingProtocol | None = None
166
+ A callable that implements the logic for performing a single
167
+ training step. See ``protocols.TrainingProtocol`` for the expected
168
+ interface, but ultimately the function should return a scalar loss
169
+ value that has a ``backward`` method.
170
+ validate_step_fn: ValidationProtocol | None = None
171
+ A callable that implements the logic for performing a single
172
+ validation step. See ``protocols.ValidationProtocol`` for the expected
173
+ interface, but in contrast to ``train_step_fn`` this function should
174
+ not return anything.
175
+ enable_static_capture: bool = True
176
+ Whether to enable static capture for the training and validation steps.
177
+ use_progress_bars: bool = True
178
+ Whether to show ``tqdm`` progress bars to display epoch and step progress.
179
+ device: str | torch.device | None = None
180
+ The device used for performing the loop. If not provided, then the device
181
+ will default to the model's device at runtime.
182
+ dtype: torch.dtype | None = None
183
+ The dtype used for performing the loop. If not provided, then the dtype
184
+ will default to ``torch.get_default_dtype()``.
185
+ checkpoint_frequency: int = 0
186
+ How often to save checkpoints during training (every N epochs).
187
+ If 0, no checkpoints are saved during training. Set via Driver before
188
+ training execution.
189
+ capture_kwargs: Any
190
+ Additional keyword arguments to pass to the static capture decorators.
191
+ """
192
+ self.train_step_fn = train_step_fn
193
+ self.validate_step_fn = validate_step_fn
194
+ self.enable_static_capture = enable_static_capture
195
+ if isinstance(device, str):
196
+ device = torch.device(device)
197
+ # check to see if we can rely on DistributedManager
198
+ if device is None and DistributedManager.is_initialized():
199
+ device = DistributedManager.device
200
+ self.device = device
201
+ if dtype is None:
202
+ dtype = torch.get_default_dtype()
203
+ self.dtype = dtype
204
+ self.capture_kwargs = capture_kwargs
205
+ self.use_progress_bars = use_progress_bars
206
+ self.capture_functions = {}
207
+ self.checkpoint_frequency = checkpoint_frequency
208
+ self.checkpoint_base_dir: Path | None = None
209
+
210
    def save_training_checkpoint(
        self,
        checkpoint_dir: Path,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        lr_scheduler: _LRScheduler | None = None,
        training_epoch: int | None = None,
    ) -> None:
        """
        Save training state to checkpoint directory.

        Model weights are saved separately (``model.mdlus`` for physicsnemo
        Modules, ``model_state.pt`` otherwise). Optimizer, scheduler, and
        epoch metadata are combined into a single ``training_state.pt`` file,
        matching the layout expected by ``load_training_checkpoint`` and the
        driver's resume logic.

        Parameters
        ----------
        checkpoint_dir: Path
            Directory to save checkpoint files (created if missing).
        model: Module | p.LearnerProtocol
            Model to save weights for.
        optimizer: Optimizer
            Optimizer to save state from.
        lr_scheduler: _LRScheduler | None
            Optional LR scheduler to save state from (stored as None when
            absent).
        training_epoch: int | None
            Current training epoch for metadata; interpreted on resume as
            the last completed epoch.
        """
        checkpoint_dir.mkdir(parents=True, exist_ok=True)

        # Save model weights separately
        if isinstance(model, Module):
            # physicsnemo Modules carry their own serialization format.
            model_path = checkpoint_dir / "model.mdlus"
            model.save(str(model_path))
        else:
            # Plain learners fall back to a raw state_dict dump.
            model_path = checkpoint_dir / "model_state.pt"
            torch.save(model.state_dict(), model_path)

        # Combine optimizer, scheduler, and epoch metadata into single file
        training_state = {
            "optimizer_state": optimizer.state_dict(),
            "lr_scheduler_state": lr_scheduler.state_dict() if lr_scheduler else None,
            "training_epoch": training_epoch,
        }
        training_state_path = checkpoint_dir / "training_state.pt"
        torch.save(training_state, training_state_path)
255
+
256
+ @staticmethod
257
+ def load_training_checkpoint(
258
+ checkpoint_dir: Path,
259
+ model: Module | p.LearnerProtocol,
260
+ optimizer: Optimizer,
261
+ lr_scheduler: _LRScheduler | None = None,
262
+ ) -> int | None:
263
+ """
264
+ Load training state from checkpoint directory.
265
+
266
+ Model weights are loaded separately. Optimizer, scheduler, and epoch
267
+ metadata are loaded from the combined training_state.pt file.
268
+
269
+ Parameters
270
+ ----------
271
+ checkpoint_dir: Path
272
+ Directory containing checkpoint files.
273
+ model: Module | p.LearnerProtocol
274
+ Model to load weights into.
275
+ optimizer: Optimizer
276
+ Optimizer to load state into.
277
+ lr_scheduler: _LRScheduler | None
278
+ Optional LR scheduler to load state into.
279
+
280
+ Returns
281
+ -------
282
+ int | None
283
+ Training epoch from metadata if available, else None.
284
+ """
285
+ # Load model weights separately
286
+ if isinstance(model, Module):
287
+ model_path = checkpoint_dir / "model.mdlus"
288
+ if model_path.exists():
289
+ model.load(str(model_path))
290
+ else:
291
+ model_state_path = checkpoint_dir / "model_state.pt"
292
+ if model_state_path.exists():
293
+ state_dict = torch.load(model_state_path, map_location="cpu")
294
+ model.load_state_dict(state_dict)
295
+
296
+ # Load combined training state (optimizer, scheduler, epoch)
297
+ training_state_path = checkpoint_dir / "training_state.pt"
298
+ if training_state_path.exists():
299
+ training_state = torch.load(training_state_path, map_location="cpu")
300
+
301
+ # Restore optimizer state
302
+ if "optimizer_state" in training_state:
303
+ optimizer.load_state_dict(training_state["optimizer_state"])
304
+
305
+ # Restore scheduler state if present
306
+ if lr_scheduler and training_state.get("lr_scheduler_state"):
307
+ lr_scheduler.load_state_dict(training_state["lr_scheduler_state"])
308
+
309
+ # Return epoch metadata
310
+ return training_state.get("training_epoch", None)
311
+
312
+ return None
313
+
314
+ @property
315
+ def amp_type(self) -> torch.dtype:
316
+ if self.dtype in [torch.float16, torch.bfloat16]:
317
+ return self.dtype
318
+ else:
319
+ return torch.float16
320
+
321
    def _create_capture_functions(
        self,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        train_step_fn: p.TrainingProtocol | None = None,
        validate_step_fn: p.ValidationProtocol | None = None,
    ) -> tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]:
        """
        Attempt to create static capture functions based off training and validation
        functions.

        This uses the Python object IDs to uniquely identify functions, and adds the
        decorated functions to an internal `capture_functions` dictionary. If the
        decorated functions already exist, then this function will be a no-op
        cache lookup.

        Parameters
        ----------
        model: Module | p.LearnerProtocol
            The model to train.
        optimizer: Optimizer
            The optimizer to use for training.
        train_step_fn: p.TrainingProtocol | None = None
            The training function to use for training. Falls back to
            ``self.train_step_fn`` when not provided.
        validate_step_fn: p.ValidationProtocol | None = None
            The validation function to use for validation. Falls back to
            ``self.validate_step_fn`` when not provided.

        Returns
        -------
        tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]
            The training and validation functions with static capture applied.
            The validation entry is None when no validation function exists.
        """
        if not train_step_fn:
            train_step_fn = self.train_step_fn
        # Cache key is the id() of the *undecorated* function, so repeated
        # calls with the same function reuse the same capture wrapper.
        # NOTE(review): id() values can be reused after the original
        # function is garbage-collected — assumes callers keep their step
        # functions alive for the loop's lifetime.
        train_func_id = id(train_step_fn)
        if train_func_id not in self.capture_functions:
            try:
                train_step_fn = StaticCaptureTraining(
                    model=model,
                    optim=optimizer,
                    amp_type=self.amp_type,
                    **self.capture_kwargs,
                )(train_step_fn)
                self.capture_functions[train_func_id] = train_step_fn
            except Exception as e:
                # Surface capture failures with context rather than letting
                # the raw decorator error propagate.
                raise RuntimeError(
                    "Failed to create static capture for `train_step_fn`. "
                ) from e
        else:
            # Cache hit: reuse the previously decorated wrapper.
            train_step_fn = self.capture_functions[train_func_id]
        if not validate_step_fn:
            validate_step_fn = self.validate_step_fn
        # Validation is optional; only wrap when a function is available.
        if validate_step_fn:
            val_func_id = id(validate_step_fn)
            if val_func_id not in self.capture_functions:
                try:
                    # Validation runs without gradients, hence the NoGrad variant.
                    validate_step_fn = StaticCaptureEvaluateNoGrad(
                        model=model, amp_type=self.amp_type, **self.capture_kwargs
                    )(validate_step_fn)
                    self.capture_functions[val_func_id] = validate_step_fn
                except Exception as e:
                    raise RuntimeError(
                        "Failed to create static capture for `validate_step_fn`. "
                    ) from e
            else:
                validate_step_fn = self.capture_functions[val_func_id]
        return train_step_fn, validate_step_fn
387
+
388
    def __call__(
        self,
        model: Module | p.LearnerProtocol,
        optimizer: Optimizer,
        train_dataloader: DataLoader,
        max_epochs: int,
        validation_dataloader: DataLoader | None = None,
        train_step_fn: p.TrainingProtocol | None = None,
        validate_step_fn: p.ValidationProtocol | None = None,
        lr_scheduler: _LRScheduler | None = None,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Performs ``max_epochs`` epochs of training and optionally validation.

        Some of the arguments, such as ``train_step_fn`` and ``validate_step_fn``,
        are optional only if the ``model`` implements the ``p.LearnerProtocol``.
        If they are passed, however, they will take precedence over the methods
        originally provided to the constructor method.

        The bare minimum required arguments for this loop to work are:
        1. A model to train
        2. An optimizer to step
        3. A training dataloader to iterate over
        4. The maximum number of epochs to train for

        If validation is required, then both ``validation_dataloader`` and
        ``validate_step_fn`` must be specified.

        Parameters
        ----------
        model: Module | p.LearnerProtocol
            The model to train.
        optimizer: torch.optim.Optimizer
            The optimizer to use for training.
        train_dataloader: DataLoader
            The dataloader to use for training.
        max_epochs: int
            The number of epochs to train for.
        validation_dataloader: DataLoader | None
            The dataloader to use for validation. If not provided, then validation
            will not be performed.
        train_step_fn: p.TrainingProtocol | None = None
            The training function to use for training. If passed, it will take
            precedence over the method provided to the constructor method.
        validate_step_fn: p.ValidationProtocol | None = None
            The validation function to use for validation.
        lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
            The learning rate scheduler to use for training.
        device: str | torch.device | None = None
            The device used for performing the loop. If provided, it will
            override the device specified in the constructor. If both values
            are not provided, then we default to PyTorch's default device.
        dtype: torch.dtype | None = None
            The dtype used for performing the loop. If provided, it will
            override the dtype specified in the constructor. If both values
            are not provided, then we default to PyTorch's default dtype.
        args: Any
            Additional arguments to pass to the training and validation
            step functions.
        kwargs: Any
            Additional keyword arguments to pass to the training and validation
            step functions.

        Raises
        ------
        RuntimeError
            If no training step function was provided either here or at
            construction time.
        """
        if not train_step_fn and not self.train_step_fn:
            raise RuntimeError(
                """
                No training step function provided.
                Either provide a `train_step_fn` to this constructor, or
                provide a `train_step_fn` to the `__call__` method.
                """
            )
        # NOTE(review): when `device` is None but `self.device` is set, this
        # leaves the local `device` as None (it is not rebound to
        # `self.device`) — confirm that is intentional before relying on it.
        if not device and not self.device:
            device = torch.get_default_device()
        if not dtype and not self.dtype:
            dtype = torch.get_default_dtype()
        # if a device is specified, move the model
        # NOTE(review): assumes `model` exposes a `.device` attribute —
        # true for physicsnemo Module, not for a bare torch.nn.Module.
        if device and device != model.device:
            # not 100% sure this will trigger issues with the optimizer
            # but allows a potentially different device to be used
            model = model.to(device)
        if self.enable_static_capture:
            # if static capture is enabled, we check for a cache hit based on
            # the incoming function IDs. If we miss, we then create new wrappers.
            train_step_fn, validate_step_fn = self._create_capture_functions(
                model, optimizer, train_step_fn, validate_step_fn
            )
        # Epochs are 1-indexed so checkpoint/frequency arithmetic below is
        # human-friendly (epoch_1, epoch_2, ...).
        epoch_iter = range(1, max_epochs + 1)
        if self.use_progress_bars:
            epoch_iter = tqdm(epoch_iter, desc="Epoch", leave=False, position=0)
        ########### EPOCH LOOP ###########
        for epoch in epoch_iter:
            model.train()
            train_iter = iter(train_dataloader)
            if self.use_progress_bars:
                train_iter = tqdm(
                    train_iter, desc="Training step", leave=False, unit="batch"
                )
            ########### TRAINING STEP LOOP ###########
            with LaunchLogger(
                "train", epoch=epoch, num_mini_batch=len(train_dataloader)
            ) as log:
                for batch in train_iter:
                    # Move/cast every tensor in the (possibly nested) batch.
                    batch = _recursive_data_device_cast(
                        batch, device=device, dtype=dtype
                    )
                    model.zero_grad(set_to_none=True)
                    loss = train_step_fn(model, batch, *args, **kwargs)
                    log.log_minibatch({"train_loss": loss.detach().item()})
                    # normally, static capture will call backward because of AMP;
                    # only drive backward/step manually when capture is off
                    if not self.enable_static_capture:
                        loss.backward()
                        optimizer.step()
            # Scheduler advances once per epoch, after the training pass.
            if lr_scheduler:
                lr_scheduler.step()
            ########### VALIDATION STEP LOOP ###########
            if validate_step_fn and validation_dataloader:
                model.eval()
                val_iter = iter(validation_dataloader)
                if self.use_progress_bars:
                    val_iter = tqdm(
                        val_iter, desc="Validation step", leave=False, unit="batch"
                    )
                with LaunchLogger(
                    "validation", epoch=epoch, num_mini_batch=len(validation_dataloader)
                ) as log:
                    for batch in val_iter:
                        batch = _recursive_data_device_cast(
                            batch, device=device, dtype=dtype
                        )
                        validate_step_fn(model, batch, *args, **kwargs)

            ########### CHECKPOINT SAVE ###########
            # Save training state at specified frequency; a base dir must have
            # been configured (by the driver) for checkpointing to happen.
            if self.checkpoint_base_dir and self.checkpoint_frequency > 0:
                if epoch % self.checkpoint_frequency == 0:
                    epoch_checkpoint_dir = self.checkpoint_base_dir / f"epoch_{epoch}"
                    self.save_training_checkpoint(
                        checkpoint_dir=epoch_checkpoint_dir,
                        model=model,
                        optimizer=optimizer,
                        lr_scheduler=lr_scheduler,
                        training_epoch=epoch,
                    )
physics_mcp/source/physicsnemo/active_learning/protocols.py ADDED
@@ -0,0 +1,1394 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This module contains base classes for active learning protocols.
19
+
20
+ These are protocols intended to be abstract, and importing these
21
+ classes specifically is intended to either be subclassed, or for
22
+ type annotations.
23
+
24
+ Protocol Architecture
25
+ ---------------------
26
+ Python ``Protocol``s are used for structural typing: essentially, they are used to
27
+ describe an expected interface in a way that is helpful for static type checkers
28
+ to make sure concrete implementations provide everything that is needed for a workflow
29
+ to function. ``Protocol``s are not actually enforced at runtime, and inheritance is not
30
+ required for them to function: as long as the implementation provides the expected
31
+ attributes and methods, they will be compatible with the protocol.
32
+
33
+ The active learning framework is built around several key protocol abstractions
34
+ that work together to orchestrate the active learning workflow:
35
+
36
+ **Core Infrastructure Protocols:**
37
+ - `AbstractQueue[T]` - Generic queue protocol for passing data between components
38
+ - `DataPool[T]` - Protocol for data reservoirs that support appending and sampling
39
+ - `ActiveLearningProtocol` - Base protocol providing common interface for all AL strategies
40
+
41
+ **Strategy Protocols (inherit from ActiveLearningProtocol):**
42
+ - `QueryStrategy` - Defines how to select data points for labeling
43
+ - `LabelStrategy` - Defines processes for adding ground truth labels to unlabeled data
44
+ - `MetrologyStrategy` - Defines procedures that assess model improvements beyond validation metrics
45
+
46
+ **Model Interface Protocols:**
47
+ - `TrainingProtocol` - Interface for training step functions
48
+ - `ValidationProtocol` - Interface for validation step functions
49
+ - `InferenceProtocol` - Interface for inference step functions
50
+ - `TrainingLoop` - Interface for complete training loop implementations
51
+ - `LearnerProtocol` - Comprehensive interface for learner modules (combines training/validation/inference)
52
+
53
+ **Orchestration Protocol:**
54
+ - `DriverProtocol` - Main orchestrator that coordinates all components in the active learning loop
55
+
56
+ Protocol Relationships
57
+ ----------------------
58
+
59
+ ```mermaid
60
+ graph TB
61
+ subgraph "Core Infrastructure"
62
+ AQ[AbstractQueue&lt;T&gt;]
63
+ DP[DataPool&lt;T&gt;]
64
+ ALP[ActiveLearningProtocol]
65
+ end
66
+
67
+ subgraph "Strategy Layer"
68
+ QS[QueryStrategy]
69
+ LS[LabelStrategy]
70
+ MS[MetrologyStrategy]
71
+ end
72
+
73
+ subgraph "Model Interface Layer"
74
+ TP[TrainingProtocol]
75
+ VP[ValidationProtocol]
76
+ IP[InferenceProtocol]
77
+ TL[TrainingLoop]
78
+ LP[LearnerProtocol]
79
+ end
80
+
81
+ subgraph "Orchestration Layer"
82
+ Driver[DriverProtocol]
83
+ end
84
+
85
+ %% Inheritance relationships (thick blue arrows)
86
+ ALP ==>|inherits| QS
87
+ ALP ==>|inherits| LS
88
+ ALP ==>|inherits| MS
89
+
90
+ %% Composition relationships (dashed green arrows)
91
+ Driver -.->|uses| LP
92
+ Driver -.->|manages| QS
93
+ Driver -.->|manages| LS
94
+ Driver -.->|manages| MS
95
+ Driver -.->|contains| DP
96
+ Driver -.->|contains| AQ
97
+
98
+ %% Protocol usage relationships (dotted purple arrows)
99
+ TL -.->|can use| TP
100
+ TL -.->|can use| VP
101
+ TL -.->|can use| LP
102
+ LP -.->|implements| TP
103
+ LP -.->|implements| VP
104
+ LP -.->|implements| IP
105
+
106
+ %% Data flow relationships (solid red arrows)
107
+ QS -->|enqueues to| AQ
108
+ AQ -->|consumed by| LS
109
+ LS -->|enqueues to| AQ
110
+
111
+ %% Styling for different relationship types
112
+ linkStyle 0 stroke:#1976d2,stroke-width:4px
113
+ linkStyle 1 stroke:#1976d2,stroke-width:4px
114
+ linkStyle 2 stroke:#1976d2,stroke-width:4px
115
+ linkStyle 3 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
116
+ linkStyle 4 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
117
+ linkStyle 5 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
118
+ linkStyle 6 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
119
+ linkStyle 7 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
120
+ linkStyle 8 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
121
+ linkStyle 9 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
122
+ linkStyle 10 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
123
+ linkStyle 11 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
124
+ linkStyle 12 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
125
+ linkStyle 13 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
126
+ linkStyle 14 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
127
+ linkStyle 15 stroke:#d32f2f,stroke-width:3px
128
+ linkStyle 16 stroke:#d32f2f,stroke-width:3px
129
+ linkStyle 17 stroke:#d32f2f,stroke-width:3px
130
+
131
+ %% Node styling
132
+ classDef coreInfra fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
133
+ classDef strategy fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
134
+ classDef modelInterface fill:#e8f5e8,stroke:#388e3c,stroke-width:2px
135
+ classDef orchestration fill:#fff3e0,stroke:#f57c00,stroke-width:3px
136
+
137
+ class AQ,DP,ALP coreInfra
138
+ class QS,LS,MS strategy
139
+ class TP,VP,IP,TL,LP modelInterface
140
+ class Driver orchestration
141
+ ```
142
+
143
+ **Relationship Legend:**
144
+ - **Blue thick arrows (==>)**: Inheritance relationships (subclass extends parent)
145
+ - **Green dashed arrows (-.->)**: Composition relationships (object contains/manages other objects)
146
+ - **Purple dotted arrows (-.->)**: Protocol usage relationships (can use or implements interface)
147
+ - **Red solid arrows (-->)**: Data flow relationships (data moves between components)
148
+
149
+ Active Learning Workflow
150
+ ------------------------
151
+
152
+ The typical active learning workflow orchestrated by `DriverProtocol` follows this sequence:
153
+
154
+ 1. **Training Phase**: Use `LearnerProtocol` or `TrainingLoop` to train the model on `training_pool`
155
+ 2. **Metrology Phase** (optional): Apply `MetrologyStrategy` instances to assess model performance
156
+ 3. **Query Phase**: Apply `QueryStrategy` instances to select samples from `unlabeled_pool` → `query_queue`
157
+ 4. **Labeling Phase** (optional): Apply `LabelStrategy` instances to label queued samples → `label_queue`
158
+ 5. **Data Integration**: Move labeled data from `label_queue` to `training_pool`
159
+
160
+ Type Parameters
161
+ ---------------
162
+ - `T`: Data structure containing both inputs and ground truth labels
163
+ - `S`: Data structure containing only inputs (no ground truth labels)
164
+ """
165
+
166
+ from __future__ import annotations
167
+
168
+ import inspect
169
+ import logging
170
+ from enum import StrEnum
171
+ from logging import Logger
172
+ from pathlib import Path
173
+ from typing import Any, Iterator, Protocol, TypeVar
174
+
175
+ import torch
176
+ from torch.optim import Optimizer
177
+ from torch.optim.lr_scheduler import _LRScheduler
178
+ from torch.utils.data import DataLoader
179
+
180
+ from physicsnemo import Module
181
+
182
+ # T is used to denote a data structure that contains inputs for a model and ground truths
183
+ T = TypeVar("T")
184
+ # S is used to denote a data structure that has inputs for a model, but no ground truth labels
185
+ S = TypeVar("S")
186
+
187
+
188
class ActiveLearningPhase(StrEnum):
    """
    An enumeration of the different phases of the active learning workflow.

    This is primarily used in the metadata for restarting an ongoing active
    learning experiment. Being a ``StrEnum``, each member compares equal to
    its lowercase string value.
    """

    # Declaration order mirrors the sequence of phases executed within a
    # single active learning iteration.
    TRAINING = "training"
    METROLOGY = "metrology"
    QUERY = "query"
    LABELING = "labeling"
    DATA_INTEGRATION = "data_integration"
201
+
202
+
203
class AbstractQueue(Protocol[T]):
    """
    Defines a generic queue protocol for data that is passed between active
    learning components.

    This can be a simple local `queue.Queue`, or a more sophisticated
    distributed queue system.

    The primary use case for this is to allow a query strategy to
    enqueue some data structure for the labeling strategy to consume,
    and once the labeling is done, enqueue to a data serialization
    workflow. While there is no explicit restriction on the **type**
    of queue that is implemented, a reasonable assumption to make
    would be a FIFO queue, unless otherwise specified by the concrete
    implementation.

    Optional Serialization Methods
    -------------------------------
    Implementations may optionally provide `to_list()` and `from_list()`
    methods for checkpoint serialization. If not provided, the queue
    will be serialized using `torch.save()` as a fallback.

    Type Parameters
    ---------------
    T
        The type of items that will be stored in the queue.
    """

    def put(self, item: T) -> None:
        """
        Method to put a data structure into the queue.

        Parameters
        ----------
        item: T
            The data structure to put into the queue.
        """
        ...

    def get(self) -> T:
        """
        Method to get a data structure from the queue.

        This method should remove the data structure from the queue,
        and return it to a consumer.

        Returns
        -------
        T
            The data structure that was removed from the queue.
        """
        ...

    def empty(self) -> bool:
        """
        Method to check if the queue is empty/has been depleted.

        Returns
        -------
        bool
            True if the queue is empty, False otherwise.
        """
        ...
266
+
267
+
268
class DataPool(Protocol[T]):
    """
    An abstract protocol for some reservoir of data that is
    used for some part of active learning, parametrized such
    that it will return data structures of an arbitrary type ``T``.

    **All** methods are left abstract, and need to be defined
    by concrete implementations. For the most part, a `torch.utils.data.Dataset`
    would match this protocol, provided that it implements the ``append`` method
    which will allow data to be persisted to a filesystem.

    Methods
    -------
    __getitem__(self, index: int) -> T:
        Method to get a single data structure from the data pool.
    __len__(self) -> int:
        Method to get the length of the data pool.
    __iter__(self) -> Iterator[T]:
        Method to iterate over the data pool.
    append(self, item: T) -> None:
        Method to append a data structure to the data pool.
    """

    def __getitem__(self, index: int) -> T:
        """
        Method to get a data structure from the data pool.

        This method should retrieve an item from the pool by a
        flat index.

        Parameters
        ----------
        index: int
            The index of the data structure to get.

        Returns
        -------
        T
            The data structure at the given index.
        """
        ...

    def __len__(self) -> int:
        """
        Method to get the length of the data pool.

        Returns
        -------
        int
            The length of the data pool.
        """
        ...

    def __iter__(self) -> Iterator[T]:
        """
        Method to iterate over the data pool.

        This method should return an iterator over the data pool.

        Returns
        -------
        Iterator[T]
            An iterator over the data pool.
        """
        ...

    def append(self, item: T) -> None:
        """
        Method to append a data structure to the data pool.

        For persistent storage pools, this will actually mean that the
        ``item`` is serialized to a filesystem.

        Parameters
        ----------
        item: T
            The data structure to append to the data pool.
        """
        ...
347
+
348
+
349
+ class ActiveLearningProtocol(Protocol):
350
+ """
351
+ This protocol acts as a basis for all active learning protocols.
352
+
353
+ This ensures that all protocols have some common interface, for
354
+ example the ability to `attach` to another object for scope
355
+ management.
356
+
357
+ Attributes
358
+ ----------
359
+ __protocol_name__: str
360
+ The name of the protocol. This is primarily used for `repr`
361
+ and `str` f-strings. This should be defined by concrete
362
+ implementations.
363
+ _args: dict[str, Any]
364
+ A dictionary of arguments that were used to instantiate the protocol.
365
+ This is used for serialization and deserialization of the protocol,
366
+ and follows the same pattern as the ``_args`` attribute of
367
+ ``physicsnemo.Module``.
368
+
369
+ Methods
370
+ -------
371
+ attach(self, other: object) -> None:
372
+ This method is used to attach the current object to another,
373
+ allowing the protocol to access the attached object's scope.
374
+ The use case for this is to allow a protocol access to the
375
+ driver's scope to access dataset, model, etc. as needed.
376
+ This needs to be implemented by concrete implementations.
377
+ is_attached: bool
378
+ Whether the current object is attached to another object.
379
+ This is left abstract, as it depends on how ``attach`` is implemented.
380
+ logger: Logger
381
+ The logger for this protocol. This is used to log information
382
+ about the protocol's progress.
383
+ _setup_logger(self) -> None:
384
+ This method is used to setup the logger for the protocol.
385
+ The default implementation is to configure the logger similarly
386
+ to how ``physicsnemo`` loggers are configured.
387
+ """
388
+
389
+ __protocol_name__: str
390
+ __protocol_type__: ActiveLearningPhase
391
+ _args: dict[str, Any]
392
+
393
    def __new__(cls, *args: Any, **kwargs: Any) -> ActiveLearningProtocol:
        """
        Wrapper for instantiating any subclass of `ActiveLearningProtocol`.

        This method will use `inspect` to capture arguments and keyword
        arguments that were used to instantiate the protocol, and stash
        them into the `_args` attribute of the instance, following
        what is done with `physicsnemo.Module`.

        This approach is useful for reconstructing strategies from checkpoints.

        Parameters
        ----------
        args: Any
            Arguments to pass to the protocol's constructor.
        kwargs: Any
            Keyword arguments to pass to the protocol's constructor.

        Returns
        -------
        ActiveLearningProtocol
            A new instance of the protocol class. The instance will have an
            `_args` attribute that contains the keys `__name__`, `__module__`,
            and `__args__` as metadata for the protocol.
        """
        out = super().__new__(cls)

        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)

        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        bound_args.apply_defaults()

        # Get args and kwargs (excluding self and unroll kwargs)
        # NOTE(review): this zip assumes `bound_args.arguments` contains an
        # entry for every parameter in signature order; that holds after
        # `apply_defaults()` as long as all required arguments were supplied
        # — confirm for partially-bound cases.
        instantiate_args = {}
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            # Skip self
            if k == "self":
                continue

            # Flatten **kwargs into the mapping; everything else is stored
            # under its parameter name.
            if param.kind == param.VAR_KEYWORD:
                instantiate_args.update(v)
            else:
                instantiate_args[k] = v

        # Store args needed for instantiation
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out
449
+
450
    def attach(self, other: object) -> None:
        """
        This method is used to attach another object to the current protocol,
        allowing the attached object to access the scope of this protocol.
        The primary reason for this is to allow the protocol to access
        things like the dataset, the learner model, etc. as needed.

        Example use cases would be for a query strategy to access the ``unlabeled_pool``;
        for a metrology strategy to access the ``validation_pool``, and for any
        strategy to be able to access the surrogate/learner model.

        This method can be as simple as setting ``self.driver = other``, but
        is left abstract in case there are other potential use cases
        where multiple protocols could share information.

        Parameters
        ----------
        other: object
            The object to attach to.
        """
        # Protocol stub: concrete implementations must provide a body.
        ...
471
+
472
    @property
    def is_attached(self) -> bool:
        """
        Property to check if the current object is already attached.

        This is left abstract, as it depends on how ``attach`` is implemented.

        Returns
        -------
        bool
            True if the current object is attached, False otherwise.
        """
        # Protocol stub: concrete implementations must provide a body.
        ...
485
+
486
@property
def logger(self) -> Logger:
    """
    Lazily configured logger for this protocol.

    The first access triggers ``_setup_logger`` to create and cache the
    logger; subsequent accesses return the cached ``_logger``.

    Returns
    -------
    Logger
        The logger for this protocol.
    """
    if hasattr(self, "_logger"):
        return self._logger
    # First access: delegate creation to the setup hook, then return
    # whatever it stored.
    self._setup_logger()
    return self._logger

@logger.setter
def logger(self, logger: Logger) -> None:
    """
    Replace the logger used by this protocol.

    Parameters
    ----------
    logger: Logger
        The logger instance to store.
    """
    self._logger = logger
514
+
515
def _setup_logger(self) -> None:
    """
    Create this protocol's logger under the ``core.active_learning``
    hierarchy, named after the protocol, and set its level to WARNING.

    Each protocol gets its own logger so output can be filtered per
    strategy.
    """
    # Assumes `__protocol_name__` is set on the concrete class --
    # TODO confirm where it is assigned.
    self.logger = logging.getLogger(
        f"core.active_learning.{self.__protocol_name__}"
    )
    # Don't add handlers here - let the parent logger handle formatting
    # This prevents duplicate console output
    self.logger.setLevel(logging.WARNING)
527
+
528
@property
def strategy_dir(self) -> Path:
    """
    Return the directory the underlying strategy can use to persist
    data, creating it if necessary.

    The path is ``<driver log_dir>/<protocol phase>/<class name>``.
    Depending on the strategy abstraction, further nesting may be
    required (e.g. active learning step index, phase, etc.).

    Returns
    -------
    Path
        Directory where this strategy may persist its data.

    Raises
    ------
    RuntimeError
        If the strategy is not attached to a driver yet.
    """
    if not self.is_attached:
        raise RuntimeError(
            f"{self.__class__.__name__} is not attached to a driver yet."
        )
    path = (
        self.driver.log_dir / str(self.__protocol_type__) / self.__class__.__name__
    )
    path.mkdir(parents=True, exist_ok=True)
    return path
557
+
558
@property
def checkpoint_dir(self) -> Path:
    """
    Convenience accessor for the checkpoint directory, created on
    demand.

    Useful for (de)serializing data tied to checkpointing; the path
    includes the driver's current active learning step index.

    Returns
    -------
    Path
        ``<driver log_dir>/checkpoints/step_<idx>``.

    Raises
    ------
    RuntimeError
        If the strategy is not attached to a driver yet.
    """
    if not self.is_attached:
        raise RuntimeError(
            f"{self.__class__.__name__} is not attached to a driver yet."
        )
    step_idx = self.driver.active_learning_step_idx
    target = self.driver.log_dir / "checkpoints" / f"step_{step_idx}"
    target.mkdir(parents=True, exist_ok=True)
    return target
586
+
587
+
588
class QueryStrategy(ActiveLearningProtocol):
    """
    Protocol for active learning query strategies.

    A query strategy decides which data points should be sent for
    labeling: concrete implementations choose the selection heuristic
    and how many samples to request.

    Attributes
    ----------
    max_samples: int
        Maximum number of samples to query; either an exact count or an
        upper bound for threshold-based selection.
    """

    max_samples: int
    __protocol_type__ = ActiveLearningPhase.QUERY

    def sample(self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any) -> None:
        """
        Select samples and enqueue them for labeling.

        The active learning driver supplies ``query_queue``;
        implementations push the chosen samples onto it in place and
        return nothing. Extra positional and keyword arguments carry
        strategy-specific information.

        Parameters
        ----------
        query_queue: AbstractQueue[T]
            Queue onto which samples selected for labeling are pushed.
        args: Any
            Additional positional arguments for the strategy.
        kwargs: Any
            Additional keyword arguments for the strategy.
        """
        ...

    def __call__(
        self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any
    ) -> None:
        """
        Invoke the strategy as a callable; forwards to :meth:`sample`.

        Parameters
        ----------
        query_queue: AbstractQueue[T]
            Queue onto which samples selected for labeling are pushed.
        args: Any
            Forwarded to :meth:`sample`.
        kwargs: Any
            Forwarded to :meth:`sample`.
        """
        self.sample(query_queue, *args, **kwargs)
651
+
652
+
653
class LabelStrategy(ActiveLearningProtocol):
    """
    Protocol for active learning label strategies.

    A label strategy produces labels for queried data points. It may be
    as lightweight as a plain Python function for demonstrating a
    concept, or wrap an external, potentially slow and complex process.

    Attributes
    ----------
    __is_external_process__: bool
        Whether labeling runs in an external process.
    __provides_fields__: set[str] | None
        Names of the fields this strategy fills in; concrete
        implementations set this and use it to write/map labeled data
        onto fields of the data structure ``T``.
    """

    __is_external_process__: bool
    __provides_fields__: set[str] | None = None
    __protocol_type__ = ActiveLearningPhase.LABELING

    def label(
        self,
        queue_to_label: AbstractQueue[T],
        serialize_queue: AbstractQueue[T],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Consume queued samples, label them, and enqueue the results.

        The driver passes ``queue_to_label`` (typically populated by
        the query strategies); labeled items are pushed onto
        ``serialize_queue`` for persistence.

        Parameters
        ----------
        queue_to_label: AbstractQueue[T]
            Queue of samples awaiting labels.
        serialize_queue: AbstractQueue[T]
            Queue that receives labeled samples for serialization.
        args: Any
            Additional positional arguments for the strategy.
        kwargs: Any
            Additional keyword arguments for the strategy.
        """
        ...

    def __call__(
        self,
        queue_to_label: AbstractQueue[T],
        serialize_queue: AbstractQueue[T],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Invoke the strategy as a callable; forwards to :meth:`label`.

        Parameters
        ----------
        queue_to_label: AbstractQueue[T]
            Queue of samples awaiting labels.
        serialize_queue: AbstractQueue[T]
            Queue that receives labeled samples for serialization.
        args: Any
            Forwarded to :meth:`label`.
        kwargs: Any
            Forwarded to :meth:`label`.
        """
        self.label(queue_to_label, serialize_queue, *args, **kwargs)
728
+
729
+
730
class MetrologyStrategy(ActiveLearningProtocol):
    """
    Protocol for active learning metrology strategies.

    Metrology assesses improvement of the underlying model beyond plain
    validation metrics, reflecting the application's requirements
    (which may include running a simulation). Each computed view is
    stored as a record of type ``S``.

    Attributes
    ----------
    records: list[S]
        History of records accumulated over the active learning run, as
        seen through this particular metrology view.
    """

    records: list[S]
    __protocol_type__ = ActiveLearningPhase.METROLOGY

    def append(self, record: S) -> None:
        """
        Add a record to this strategy's history.

        Parameters
        ----------
        record: S
            The record to store.
        """
        self.records.append(record)

    def __len__(self) -> int:
        """
        Number of records accumulated so far.

        Returns
        -------
        int
            The current record count.
        """
        return len(self.records)

    def serialize_records(
        self, path: Path | None = None, *args: Any, **kwargs: Any
    ) -> None:
        """
        Persist the accumulated records.

        Concrete implementations decide the storage format (JSON file,
        database, etc.); ``strategy_dir`` gives a sensible default
        location.

        Parameters
        ----------
        path: Path | None
            Destination for the records. When omitted, implementations
            should fall back to a reasonable default, such as the
            checkpoint directory or ``strategy_dir``.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...

    def load_records(self, path: Path | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Restore records persisted by :meth:`serialize_records`,
        overwriting the ``records`` attribute in place.

        Parameters
        ----------
        path: Path | None
            Location to load records from. When omitted,
            implementations should load the latest records as a
            sensible default.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...

    def compute(self, *args: Any, **kwargs: Any) -> None:
        """
        Compute this metrology view of model performance.

        No data is passed directly; implementations draw what they need
        from the ``DataPool`` instances connected to the driver, format
        the result as a record ``S``, and append it to ``records``.

        Parameters
        ----------
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """
        Invoke the strategy as a callable; forwards to :meth:`compute`.

        Parameters
        ----------
        args: Any
            Forwarded to :meth:`compute`.
        kwargs: Any
            Forwarded to :meth:`compute`.
        """
        self.compute(*args, **kwargs)

    def reset(self) -> None:
        """
        Reset stateful attributes; by default, empties ``records``.
        """
        self.records = []
863
+
864
+
865
class TrainingProtocol(Protocol):
    """
    Interface for a single training step: given a model and input data,
    return the reduced, differentiable loss tensor.

    Any function matching ``__call__``'s signature satisfies this
    protocol.
    """

    def __call__(
        self, model: Module, data: T, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """
        Run one training step on a sample or batch.

        For a PhysicsNeMo ``Module`` with trainable parameters, the
        returned tensor must be ready for ``backward``. Any logging
        tied to training belongs inside this function. For best
        performance, the implementation should be wrappable with
        ``StaticCaptureTraining``.

        Parameters
        ----------
        model: Module
            The model being trained.
        data: T
            Batch holding both inputs and ground truths for the loss.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The reduced, differentiable loss tensor.

        Example
        -------
        Minimum viable implementation:
        >>> import torch
        >>> def training_step(model, data):
        ...     output = model(data)
        ...     loss = torch.sum(torch.pow(output - data, 2))
        ...     return loss
        """
        ...
916
+
917
+
918
class ValidationProtocol(Protocol):
    """
    Interface for a single validation step: given a model and input
    data, compute metrics of interest and, if relevant, log them.

    Any function matching ``__call__``'s signature satisfies this
    protocol.
    """

    def __call__(self, model: Module, data: T, *args: Any, **kwargs: Any) -> None:
        """
        Run one validation step on a sample or batch.

        Called during validation steps **only** — query and metrology
        steps use ``inference_step`` instead. This function should not
        return anything: it computes metrics over a validation/test set
        and performs any required logging itself.

        If the model's forward pass does not need autograd, consider
        wrapping implementations with ``StaticCaptureEvaluateNoGrad``
        for performance.

        Parameters
        ----------
        model: Module
            The model being validated.
        data: T
            Batch holding both inputs and ground truths.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.

        Example
        -------
        Minimum viable implementation (metrics are logged, not
        returned, to honor the "return nothing" contract):
        >>> import torch
        >>> def validation_step(model, data):
        ...     output = model(data)
        ...     loss = torch.sum(torch.pow(output - data, 2))
        ...     print(f"validation loss: {loss.item()}")
        """
        ...
968
+
969
+
970
class InferenceProtocol(Protocol):
    """
    Interface for a single inference step: given a model and input
    data, return the output of the model's forward pass.

    Any function matching ``__call__``'s signature satisfies this
    protocol.
    """

    def __call__(self, model: Module, data: S, *args: Any, **kwargs: Any) -> Any:
        """
        Run one inference step on a sample or batch.

        Called during query and metrology steps. The output should be
        minimally processed, leaving transformations to the strategies
        that use this protocol. Unlike the training and validation
        protocols, the data structure ``S`` need not carry ground
        truths for a loss.

        As with ``ValidationProtocol``, consider wrapping
        implementations with ``StaticCaptureInference`` when the
        architecture allows it.

        Parameters
        ----------
        model: Module
            The model used for inference.
        data: S
            Batch of inputs only.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.

        Returns
        -------
        Any
            The output of the model's forward pass.

        Example
        -------
        Minimum viable implementation:
        >>> def inference_step(model, data):
        ...     output = model(data)
        ...     return output
        """
        ...
1021
+
1022
+
1023
class TrainingLoop(Protocol):
    """
    Protocol for the training loop invoked during the training phase of
    active learning: train the model for a number of epochs or steps,
    optionally validating along the way.

    When the model implements ``LearnerProtocol``, ``train_fn`` and
    ``validate_fn`` become optional; if supplied anyway, they take
    precedence over the learner's own methods.

    If graph capture/compilation is intended, wrap ``train_fn`` and
    ``validate_fn`` with ``StaticCaptureTraining`` and
    ``StaticCaptureEvaluateNoGrad``, respectively.
    """

    def __call__(
        self,
        model: Module | LearnerProtocol,
        optimizer: Optimizer,
        train_dataloader: DataLoader,
        validation_dataloader: DataLoader | None = None,
        train_step_fn: TrainingProtocol | None = None,
        validate_step_fn: ValidationProtocol | None = None,
        max_epochs: int | None = None,
        max_train_steps: int | None = None,
        max_val_steps: int | None = None,
        lr_scheduler: _LRScheduler | None = None,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Minimal viable training loop.

        ``model``'s parameters, tracked by ``optimizer``, are updated
        either over ``max_epochs`` full passes of ``train_dataloader``,
        or over ``max_train_steps`` batches (mutually exclusive
        options). ``max_train_steps`` may exceed one epoch, in which
        case the data is recycled until the step budget is exhausted.

        Validation runs at the end of each epoch (or when the training
        step budget is reached), only when both ``validate_step_fn``
        and ``validation_dataloader`` are provided; ``max_val_steps``
        caps the validation batches per epoch.

        Pseudocode for epoch-based training:

        .. code-block:: python

            for epoch in range(max_epochs):
                for train_idx, batch in enumerate(train_dataloader):
                    optimizer.zero_grad()
                    loss = train_step_fn(model, batch)
                    loss.backward()
                    optimizer.step()
                    if train_idx + 1 == max_train_steps:
                        break
                if validate_step_fn and validation_dataloader:
                    for val_idx, batch in enumerate(validation_dataloader):
                        validate_step_fn(model, batch)
                        if val_idx + 1 == max_val_steps:
                            break

        With a ``LearnerProtocol`` model, ``model.training_step(batch)``
        and ``model.validation_step(batch)`` replace the step functions.
        The key difference: a ``train_step_fn`` excludes the backward
        pass and optimizer step, whereas ``LearnerProtocol`` methods
        encapsulate them.

        ``device`` and ``dtype`` select where and how the loop runs;
        when omitted, reasonable defaults such as
        ``torch.get_default_device()`` / ``torch.get_default_dtype()``
        should be used.

        Parameters
        ----------
        model: Module | LearnerProtocol
            The model to train.
        optimizer: Optimizer
            Optimizer tracking the model's parameters.
        train_dataloader: DataLoader
            Source of training batches.
        validation_dataloader: DataLoader | None
            Source of validation batches, if any.
        train_step_fn: TrainingProtocol | None
            Training step function. Optional only when ``model``
            implements ``LearnerProtocol``; when both are present this
            function takes precedence over
            ``LearnerProtocol.training_step``.
        validate_step_fn: ValidationProtocol | None
            Validation step function, used only together with
            ``validation_dataloader``; when ``model`` implements
            ``LearnerProtocol``, this takes precedence over
            ``LearnerProtocol.validation_step``.
        max_epochs: int | None
            Number of epochs to train. Mutually exclusive with
            ``max_train_steps``.
        max_train_steps: int | None
            Number of training batches to run. Mutually exclusive with
            ``max_epochs``; if larger than one epoch, data is recycled.
        max_val_steps: int | None
            Cap on validation batches per training epoch; ``None`` uses
            the full validation set.
        lr_scheduler: _LRScheduler | None
            Optional scheduler used to adjust the optimizer's learning
            rate during training; when omitted, the learning rate is
            left untouched by this function.
        device: str | torch.device | None
            Device for the training loop.
        dtype: torch.dtype | None
            Dtype for the training loop.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...
1164
+
1165
+
1166
class LearnerProtocol:
    """
    Protocol for the learner (surrogate) part of an active learning
    algorithm: a set of trainable parameters that is optimized and then
    used for inference and evaluation.

    The required methods cover every active learning phase. As with the
    other protocols in this module, this is only the required
    interface, not an implementation.
    """

    def training_step(self, data: T, *args: Any, **kwargs: Any) -> torch.Tensor:
        """
        Run the training logic for a single batch.

        Called in training steps **only** — gradients are computed and
        used to update parameters here (the backward pass and optimizer
        step are encapsulated). Mirrors ``TrainingProtocol`` with the
        model being ``self``; accordingly the returned loss is what
        ``TrainingLoop`` consumes via
        ``loss = model.training_step(batch)``.

        Where gradients are not needed, implement ``validation_step``
        instead.

        Parameters
        ----------
        data: T
            Batch of training data (inputs and ground truths).
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.

        Returns
        -------
        torch.Tensor
            The reduced loss for the batch.
        """
        ...

    def validation_step(self, data: T, *args: Any, **kwargs: Any) -> None:
        """
        Run the validation logic for a single batch.

        May match the forward pass, with no weight updates. Called in
        validation steps **only** — query and metrology steps use
        ``inference_step`` instead. Mirrors ``ValidationProtocol`` with
        the model being ``self``.

        Parameters
        ----------
        data: T
            Batch of validation data (inputs and ground truths).
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...

    def inference_step(self, data: T | S, *args: Any, **kwargs: Any) -> Any:
        """
        Run the inference logic for a single batch.

        May match the forward pass exactly, but gives implementations
        the chance to differentiate (or not). Called during query and
        metrology steps. Mirrors ``InferenceProtocol`` with the model
        being ``self``, and so returns the (minimally processed)
        forward-pass output.

        Parameters
        ----------
        data: T | S
            Batch to run inference on; ground truths are not required.
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.

        Returns
        -------
        Any
            The output of the forward pass.
        """
        ...

    @property
    def parameters(self) -> Iterator[torch.Tensor]:
        """
        Iterator over the learner's parameters.

        NOTE(review): declared as a property here, whereas
        ``torch.nn.Module.parameters`` is a method — confirm intended;
        a ``Module`` subclass will expose a callable, not a property.

        Returns
        -------
        Iterator[torch.Tensor]
            An iterator over the parameters of the learner.
        """
        ...

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Forward pass for a single batch.

        Used across all active learning steps; contains the logic for
        how the model ingests data and produces predictions.

        Parameters
        ----------
        args: Any
            Additional positional arguments for the model.
        kwargs: Any
            Additional keyword arguments for the model.

        Returns
        -------
        Any
            The output of the model's forward pass.
        """
        ...
1288
+
1289
+
1290
class DriverProtocol:
    """
    Reference interface for an active learning driver; see the
    ``driver`` module for a concrete implementation. Provided mostly
    for type hinting without incurring circular imports.

    Attributes
    ----------
    learner: LearnerProtocol
        Learner module used as the surrogate within the loop.
    query_strategies: list[QueryStrategy]
        Strategies applied in sequence to populate ``query_queue`` with
        samples to label.
    query_queue: AbstractQueue[T]
        Queue of samples awaiting labels; ``QueryStrategy`` instances
        enqueue onto it.
    label_strategy: LabelStrategy | None
        Single (optional) labeling strategy — unlike the other
        strategies, only one is supported. It consumes ``query_queue``
        and enqueues labeled data onto ``label_queue``.
    label_queue: AbstractQueue[T] | None
        Queue of freshly labeled data; the driver serializes its
        contents to a persistent format.
    metrology_strategies: list[MetrologyStrategy] | None
        Strategies, applied in sequence, that assess surrogate
        performance.
    training_pool: DataPool[T]
        Pool of training data; mutable, so newly labeled data can be
        added over the course of active learning.
    validation_pool: DataPool[T] | None
        Pool used for both conventional validation and metrology;
        treated as immutable over the run. Optional, as both phases
        are.
    unlabeled_pool: DataPool[T] | None
        Optional pool that query strategies may deplete when selecting
        points to label. Conceptually this can also be a generative
        model — a distribution of data rather than a static dataset.
    """

    learner: LearnerProtocol
    query_strategies: list[QueryStrategy]
    query_queue: AbstractQueue[T]
    label_strategy: LabelStrategy | None
    label_queue: AbstractQueue[T] | None
    metrology_strategies: list[MetrologyStrategy] | None
    training_pool: DataPool[T]
    validation_pool: DataPool[T] | None
    unlabeled_pool: DataPool[T] | None

    def active_learning_step(self, *args: Any, **kwargs: Any) -> None:
        """
        Run a single pass of the active learning loop.

        Intended order: training, metrology, query, labeling — with
        metrology and labeling being optional.

        Parameters
        ----------
        args: Any
            Additional positional arguments.
        kwargs: Any
            Additional keyword arguments.
        """
        ...

    def _setup_logger(self) -> None:
        """
        Configure the driver's logger.

        Concrete implementations should support scoped logging, e.g.
        including the active learning iteration count.
        """
        ...

    def attach_strategies(self) -> None:
        """
        Call ``attach(self)`` on every configured strategy, giving each
        one access to the driver's scope.

        This lets, for example, any non-label strategy reach the
        learner, a query strategy reach ``unlabeled_pool``, and a
        metrology strategy reach ``validation_pool``.
        """
        for strategy in self.query_strategies:
            strategy.attach(self)
        if self.label_strategy:
            self.label_strategy.attach(self)
        for strategy in self.metrology_strategies or []:
            strategy.attach(self)
physics_mcp/source/physicsnemo/constants.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ constant values used by PhysicsNeMo
19
+ """
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
# Separator token used when encoding derivative names as strings.
diff_str: str = "__"


def diff(y: str, x: str, degree: int = 1) -> str:
    """Build the derivative string for ``y`` differentiated w.r.t. ``x``.

    Parameters
    ----------
    y : str
        Name of the dependent variable.
    x : str
        Name of the independent variable.
    degree : int, optional
        Order of the derivative, by default 1.

    Returns
    -------
    str
        The variable names joined by ``diff_str``, e.g.
        ``diff("u", "x", 2) == "u__x__x"``.
    """
    parts = [y]
    for _ in range(degree):
        parts.append(x)
    return diff_str.join(parts)
31
+
32
+
33
# Default tensor / array dtypes used across PhysicsNeMo.
# Change here for float16 or float64 precision.
tf_dt: torch.dtype = torch.float32
np_dt: type = np.float32

# tensorboard naming
TF_SUMMARY: bool = False

# Pytorch Version for which JIT will be default on
# Torch version of NGC container 22.08
JIT_PYTORCH_VERSION: str = "1.13.0a0+d321be6"

# No scaling is needed if using NO_OP_SCALE
# NOTE(review): presumably a (shift, scale) pair — confirm against consumers.
NO_OP_SCALE = (0.0, 1.0)

# If using NO_OP_NORM, it is effectively doing no normalization
# NOTE(review): presumably a (min, max) normalization range — confirm against consumers.
NO_OP_NORM = (-1.0, 1.0)
physics_mcp/source/physicsnemo/datapipes/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
physics_mcp/source/physicsnemo/datapipes/benchmarks/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
physics_mcp/source/physicsnemo/datapipes/benchmarks/darcy.py ADDED
@@ -0,0 +1,322 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import sys
18
+ from dataclasses import dataclass
19
+ from typing import Dict, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import warp as wp
24
+
25
+ from ..datapipe import Datapipe
26
+ from ..meta import DatapipeMetaData
27
+ from .kernels.finite_difference import (
28
+ darcy_mgrid_jacobi_iterative_batched_2d,
29
+ mgrid_inf_residual_batched_2d,
30
+ )
31
+ from .kernels.initialization import init_uniform_random_4d
32
+ from .kernels.utils import (
33
+ bilinear_upsample_batched_2d,
34
+ fourier_to_array_batched_2d,
35
+ threshold_3d,
36
+ )
37
+
38
+ Tensor = torch.Tensor
39
+ # TODO unsure if better to remove this. Keeping this in for now
40
+ wp.init()
41
+
42
+
43
@dataclass
class MetaData(DatapipeMetaData):
    """Metadata flags for the Darcy2D datapipe."""

    name: str = "Darcy2D"
    # Optimization
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel
    ddp_sharding: bool = False
51
+
52
+
53
class Darcy2D(Datapipe):
    """2D Darcy flow benchmark problem datapipe.

    This datapipe continuously generates solutions to the 2D Darcy equation with
    variable permeability. All samples are generated on the fly and are meant to be
    a benchmark problem for testing data driven models. Permeability is drawn from
    a random Fourier series and thresholded to give a piecewise constant function.
    The solution is obtained using a GPU enabled multi-grid Jacobi iterative method.

    Parameters
    ----------
    resolution : int, optional
        Resolution to run simulation at, by default 256
    batch_size : int, optional
        Batch size of simulations, by default 64
    nr_permeability_freq : int, optional
        Number of frequencies to use for generating random permeability. Higher
        values will give higher freq permeability fields, by default 5
    max_permeability : float, optional
        Max permeability, by default 2.0
    min_permeability : float, optional
        Min permeability, by default 0.5
    max_iterations : int, optional
        Maximum iterations to use for each multi-grid, by default 30000
    convergence_threshold : float, optional
        Solver L-Infinity convergence threshold, by default 1e-6
    iterations_per_convergence_check : int, optional
        Number of Jacobi iterations to run before checking convergence, by default 1000
    nr_multigrids : int, optional
        Number of multi-grid levels, by default 4
    normaliser : Union[Dict[str, Tuple[float, float]], None], optional
        Dictionary with keys ``permeability`` and ``darcy``. The values for these
        keys are two floats corresponding to mean and std ``(mean, std)``.
    device : Union[str, torch.device], optional
        Device for datapipe to place data on, by default "cuda"

    Raises
    ------
    ValueError
        Incompatible multi-grid and resolution settings
    """

    def __init__(
        self,
        resolution: int = 256,
        batch_size: int = 64,
        nr_permeability_freq: int = 5,
        max_permeability: float = 2.0,
        min_permeability: float = 0.5,
        max_iterations: int = 30000,
        convergence_threshold: float = 1e-6,
        iterations_per_convergence_check: int = 1000,
        nr_multigrids: int = 4,
        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        device: Union[str, torch.device] = "cuda",
    ):
        super().__init__(meta=MetaData())

        # simulation params
        self.resolution = resolution
        self.batch_size = batch_size
        self.nr_permeability_freq = nr_permeability_freq
        self.max_permeability = max_permeability
        self.min_permeability = min_permeability
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.iterations_per_convergence_check = iterations_per_convergence_check
        self.nr_multigrids = nr_multigrids
        self.normaliser = normaliser

        # check normaliser keys
        if self.normaliser is not None:
            if not {"permeability", "darcy"}.issubset(set(self.normaliser.keys())):
                raise ValueError(
                    "normaliser need to have keys permeability and darcy with mean and std"
                )

        # Set up device for warp, warp has same naming convention as torch.
        if isinstance(device, torch.device):
            device = str(device)
        self.device = device

        # spatial dims
        self.dx = 1.0 / (self.resolution + 1)  # pad edges by 1 for multi-grid
        self.dim = (self.batch_size, self.resolution + 1, self.resolution + 1)
        # [cos/sin combinations, batch, freq, freq] coefficients for the
        # random Fourier series that seeds the permeability field
        self.fourier_dim = (
            4,
            self.batch_size,
            self.nr_permeability_freq,
            self.nr_permeability_freq,
        )

        # assert resolution is compatible with multi-grid method
        if (resolution % 2 ** (nr_multigrids - 1)) != 0:
            raise ValueError("Resolution is incompatible with number of sub grids.")

        # allocate arrays for constructing dataset
        self.darcy0 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.darcy1 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.permeability = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rand_fourier = wp.zeros(self.fourier_dim, dtype=float, device=self.device)
        self.inf_residual = wp.zeros([1], dtype=float, device=self.device)

        # Output tensors (created lazily on first iteration, then reused so the
        # yielded tensors are static allocations suitable for CUDA graphs)
        self.output_k = None
        self.output_p = None

    def initialize_batch(self) -> None:
        """Initializes arrays for new batch of simulations"""

        # initialize permeability: draw random Fourier coefficients, evaluate the
        # series on the grid, then threshold to a piecewise constant field
        self.permeability.zero_()
        seed = np.random.randint(np.iinfo(np.uint64).max, dtype=np.uint64)
        wp.launch(
            kernel=init_uniform_random_4d,
            dim=self.fourier_dim,
            inputs=[self.rand_fourier, -1.0, 1.0, seed],
            device=self.device,
        )
        wp.launch(
            kernel=fourier_to_array_batched_2d,
            dim=self.dim,
            inputs=[
                self.permeability,
                self.rand_fourier,
                self.nr_permeability_freq,
                self.resolution,
                self.resolution,
            ],
            device=self.device,
        )
        # threshold around 0: below -> min_permeability, above -> max_permeability
        wp.launch(
            kernel=threshold_3d,
            dim=self.dim,
            inputs=[
                self.permeability,
                0.0,
                self.min_permeability,
                self.max_permeability,
            ],
            device=self.device,
        )

        # zero darcy arrays
        self.darcy0.zero_()
        self.darcy1.zero_()

    def generate_batch(self) -> None:
        """Solve for new batch of simulations"""

        # initialize tensors with random permeability
        self.initialize_batch()

        # run solver, coarsest grid first, halving the reduction factor each level
        for res in range(self.nr_multigrids):
            # calculate grid reduction factor and reduced dim
            grid_reduction_factor = 2 ** (self.nr_multigrids - res - 1)
            if grid_reduction_factor > 1:
                multigrid_dim = tuple(
                    [self.batch_size] + 2 * [(self.resolution) // grid_reduction_factor]
                )
            else:
                multigrid_dim = self.dim

            # run till max steps is reached
            for k in range(
                self.max_iterations // self.iterations_per_convergence_check
            ):
                # run jacobi iterations
                for s in range(self.iterations_per_convergence_check):
                    # iterate solver
                    wp.launch(
                        kernel=darcy_mgrid_jacobi_iterative_batched_2d,
                        dim=multigrid_dim,
                        inputs=[
                            self.darcy0,
                            self.darcy1,
                            self.permeability,
                            1.0,
                            self.dim[1],
                            self.dim[2],
                            self.dx,
                            grid_reduction_factor,
                        ],
                        device=self.device,
                    )

                    # swap buffers
                    (self.darcy0, self.darcy1) = (self.darcy1, self.darcy0)

                # compute residual (L-infinity norm between successive iterates;
                # .numpy() forces a device sync here)
                self.inf_residual.zero_()
                wp.launch(
                    kernel=mgrid_inf_residual_batched_2d,
                    dim=multigrid_dim,
                    inputs=[
                        self.darcy0,
                        self.darcy1,
                        self.inf_residual,
                        grid_reduction_factor,
                    ],
                    device=self.device,
                )
                normalized_inf_residual = self.inf_residual.numpy()[0]

                # check if converged (threshold relaxed on coarser grids)
                if normalized_inf_residual < (
                    self.convergence_threshold * grid_reduction_factor
                ):
                    break

            # upsample to higher resolution
            if grid_reduction_factor > 1:
                wp.launch(
                    kernel=bilinear_upsample_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.darcy0,
                        self.dim[1],
                        self.dim[2],
                        grid_reduction_factor,
                    ],
                    device=self.device,
                )

    def __iter__(self) -> Tuple[Tensor, Tensor]:
        """
        Yields
        ------
        Infinite iterator that returns a dict with keys ``permeability`` and
        ``darcy`` whose values are tensors of size [batch, 1, resolution,
        resolution].
        """
        # NOTE(review): each item yielded is a dict, not a tuple — the return
        # annotation above is inaccurate and kept only for compatibility.
        # infinite generator
        while True:
            # run simulation
            self.generate_batch()

            # convert warp arrays to pytorch (shares device memory, no copy)
            permeability = wp.to_torch(self.permeability)
            darcy = wp.to_torch(self.darcy0)

            # add channel dims
            permeability = torch.unsqueeze(permeability, axis=1)
            darcy = torch.unsqueeze(darcy, axis=1)

            # crop edges by 1 from multi-grid TODO messy
            permeability = permeability[:, :, : self.resolution, : self.resolution]
            darcy = darcy[:, :, : self.resolution, : self.resolution]

            # normalize values
            if self.normaliser is not None:
                permeability = (
                    permeability - self.normaliser["permeability"][0]
                ) / self.normaliser["permeability"][1]
                darcy = (darcy - self.normaliser["darcy"][0]) / self.normaliser[
                    "darcy"
                ][1]

            # CUDA graphs static copies: keep yielding the same tensor objects,
            # refreshing their contents in place
            if self.output_k is None:
                self.output_k = permeability
                self.output_p = darcy
            else:
                self.output_k.data.copy_(permeability)
                self.output_p.data.copy_(darcy)

            yield {"permeability": self.output_k, "darcy": self.output_p}

    def __len__(self):
        # effectively infinite: samples are generated on the fly forever
        return sys.maxsize
physics_mcp/source/physicsnemo/datapipes/benchmarks/kelvin_helmholtz.py ADDED
@@ -0,0 +1,436 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import sys
18
+ from dataclasses import dataclass
19
+ from typing import Dict, Tuple, Union
20
+
21
+ import numpy as np
22
+ import torch
23
+ import warp as wp
24
+
25
+ from ..datapipe import Datapipe
26
+ from ..meta import DatapipeMetaData
27
+ from .kernels.finite_volume import (
28
+ euler_apply_flux_batched_2d,
29
+ euler_conserved_to_primitive_batched_2d,
30
+ euler_extrapolation_batched_2d,
31
+ euler_get_flux_batched_2d,
32
+ euler_primitive_to_conserved_batched_2d,
33
+ initialize_kelvin_helmoltz_batched_2d,
34
+ )
35
+ from .kernels.initialization import init_uniform_random_2d
36
+
37
+ Tensor = torch.Tensor
38
+ # TODO unsure if better to remove this
39
+ wp.init()
40
+
41
+
42
@dataclass
class MetaData(DatapipeMetaData):
    """Metadata flags for the KelvinHelmholtz2D datapipe."""

    name: str = "KelvinHelmholtz2D"
    # Optimization
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel
    ddp_sharding: bool = False
50
+
51
+
52
class KelvinHelmholtz2D(Datapipe):
    """Kelvin-Helmholtz instability benchmark problem datapipe.

    This datapipe continuously generates samples with random initial conditions. All
    samples are generated on the fly and are meant to be a benchmark problem for
    testing data driven models. Initial conditions are given in the form of small
    perturbations. The solution is obtained using a GPU enabled Finite Volume Method.

    Parameters
    ----------
    resolution : int, optional
        Resolution to run simulation at, by default 512
    batch_size : int, optional
        Batch size of simulations, by default 16
    seq_length : int, optional
        Sequence length of output samples, by default 8
    nr_perturbation_freq : int, optional
        Number of frequencies to use for generating random initial perturbations, by default 5
    perturbation_range : float, optional
        Range to use for random perturbations. This value will be the max amplitude
        of the initial perturbation, by default 0.1
    nr_snapshots : int, optional
        Number of snapshots of simulation to generate for data generation. This will
        control how long the simulation is run for, by default 256
    iteration_per_snapshot : int, optional
        Number of finite volume steps to take between each snapshot. Each step size
        is fixed as the smallest possible value that satisfies the
        Courant-Friedrichs-Lewy condition, by default 32
    gamma : float, optional
        Heat capacity ratio, by default 5.0/3.0
    normaliser : Union[Dict[str, Tuple[float, float]], None], optional
        Dictionary with keys ``density``, ``velocity``, and ``pressure``. The values
        for these keys are two floats corresponding to mean and std ``(mean, std)``.
    device : Union[str, torch.device], optional
        Device for datapipe to place data on, by default "cuda"
    """

    def __init__(
        self,
        resolution: int = 512,
        batch_size: int = 16,
        seq_length: int = 8,
        nr_perturbation_freq: int = 5,
        perturbation_range: float = 0.1,
        nr_snapshots: int = 256,
        iteration_per_snapshot: int = 32,
        gamma: float = 5.0 / 3.0,
        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        device: Union[str, torch.device] = "cuda",
    ):
        super().__init__(meta=MetaData())

        # simulation params
        self.resolution = resolution
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.nr_perturbation_freq = nr_perturbation_freq
        self.perturbation_range = perturbation_range
        self.nr_snapshots = nr_snapshots
        self.iteration_per_snapshot = iteration_per_snapshot
        self.gamma = gamma
        self.courant_fac = 0.4  # hard set
        self.normaliser = normaliser

        # check normaliser keys
        if self.normaliser is not None:
            if not {"density", "velocity", "pressure"}.issubset(
                set(self.normaliser.keys())
            ):
                raise ValueError(
                    "normaliser need to have keys `density`, `velocity` and `pressure` with mean and std"
                )

        # Set up device for warp, warp has same naming convention as torch.
        if isinstance(device, torch.device):
            device = str(device)
        self.device = device

        # spatial dims
        self.dx = 1.0 / resolution
        self.dt = (
            self.courant_fac * self.dx / (np.sqrt(self.gamma * 5.0) + 2.0)
        )  # hard set to smallest possible step needed
        self.vol = self.dx**2  # cell volume (area in 2D)
        self.dim = (self.batch_size, self.resolution, self.resolution)

        # allocate array for initial freq perturbation
        self.w = wp.zeros(
            (self.batch_size, self.nr_perturbation_freq),
            dtype=float,
            device=self.device,
        )

        # allocate conservation quantities (mass, momentum, energy)
        self.mass = wp.zeros(self.dim, dtype=float, device=self.device)
        self.mom = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.e = wp.zeros(self.dim, dtype=float, device=self.device)

        # allocate primitive quantities (density, velocity, pressure)
        self.rho = wp.zeros(self.dim, dtype=float, device=self.device)
        self.vel = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.p = wp.zeros(self.dim, dtype=float, device=self.device)

        # allocate flux values for computation
        self.mass_flux_x = wp.zeros(self.dim, dtype=float, device=self.device)
        self.mass_flux_y = wp.zeros(self.dim, dtype=float, device=self.device)
        self.mom_flux_x = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.mom_flux_y = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.e_flux_x = wp.zeros(self.dim, dtype=float, device=self.device)
        self.e_flux_y = wp.zeros(self.dim, dtype=float, device=self.device)

        # allocate extrapolation values for computation (left/right face
        # extrapolations in x and y for each primitive quantity)
        self.rho_xl = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rho_xr = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rho_yl = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rho_yr = wp.zeros(self.dim, dtype=float, device=self.device)
        self.vel_xl = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.vel_xr = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.vel_yl = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.vel_yr = wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
        self.p_xl = wp.zeros(self.dim, dtype=float, device=self.device)
        self.p_xr = wp.zeros(self.dim, dtype=float, device=self.device)
        self.p_yl = wp.zeros(self.dim, dtype=float, device=self.device)
        self.p_yr = wp.zeros(self.dim, dtype=float, device=self.device)

        # allocate arrays for storing results (one array per snapshot)
        self.seq_rho = [
            wp.zeros(self.dim, dtype=float, device=self.device)
            for _ in range(self.nr_snapshots)
        ]
        self.seq_vel = [
            wp.zeros(self.dim, dtype=wp.vec2, device=self.device)
            for _ in range(self.nr_snapshots)
        ]
        self.seq_p = [
            wp.zeros(self.dim, dtype=float, device=self.device)
            for _ in range(self.nr_snapshots)
        ]

        # Output tensors (created lazily on first iteration, then reused so the
        # yielded tensors are static allocations suitable for CUDA graphs)
        self.output_rho = None
        self.output_vel = None
        self.output_p = None

    def initialize_batch(self) -> None:
        """Initializes arrays for new batch of simulations"""

        # initialize random Fourier freq
        seed = np.random.randint(np.iinfo(np.uint64).max, dtype=np.uint64)
        wp.launch(
            init_uniform_random_2d,
            dim=[self.batch_size, self.nr_perturbation_freq],
            inputs=[self.w, -self.perturbation_range, self.perturbation_range, seed],
            device=self.device,
        )

        # initialize fields
        wp.launch(
            initialize_kelvin_helmoltz_batched_2d,
            dim=self.dim,
            inputs=[
                self.rho,
                self.vel,
                self.p,
                self.w,
                0.05 / np.sqrt(2.0),
                self.dim[1],
                self.dim[2],
                self.nr_perturbation_freq,
            ],
            device=self.device,
        )
        # convert the initial primitive fields to conserved quantities, which is
        # what the finite volume update operates on
        wp.launch(
            euler_primitive_to_conserved_batched_2d,
            dim=self.dim,
            inputs=[
                self.rho,
                self.vel,
                self.p,
                self.mass,
                self.mom,
                self.e,
                self.gamma,
                self.vol,
                self.dim[1],
                self.dim[2],
            ],
            device=self.device,
        )

    def generate_batch(self) -> None:
        """Solve for new batch of simulations"""

        # initialize tensors with random coef
        self.initialize_batch()

        # run solver
        for s in range(self.nr_snapshots):
            # save current primitive fields as the snapshot for this step
            wp.copy(self.seq_rho[s], self.rho)
            wp.copy(self.seq_vel[s], self.vel)
            wp.copy(self.seq_p[s], self.p)

            # iterations
            for i in range(self.iteration_per_snapshot):
                # compute primitives
                wp.launch(
                    euler_conserved_to_primitive_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.mass,
                        self.mom,
                        self.e,
                        self.rho,
                        self.vel,
                        self.p,
                        self.gamma,
                        self.vol,
                        self.dim[1],
                        self.dim[2],
                    ],
                    device=self.device,
                )

                # compute extrapolations to faces
                wp.launch(
                    euler_extrapolation_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.rho,
                        self.vel,
                        self.p,
                        self.rho_xl,
                        self.rho_xr,
                        self.rho_yl,
                        self.rho_yr,
                        self.vel_xl,
                        self.vel_xr,
                        self.vel_yl,
                        self.vel_yr,
                        self.p_xl,
                        self.p_xr,
                        self.p_yl,
                        self.p_yr,
                        self.gamma,
                        self.dx,
                        self.dt,
                        self.dim[1],
                        self.dim[2],
                    ],
                    device=self.device,
                )

                # compute fluxes
                wp.launch(
                    euler_get_flux_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.rho_xl,
                        self.rho_xr,
                        self.rho_yl,
                        self.rho_yr,
                        self.vel_xl,
                        self.vel_xr,
                        self.vel_yl,
                        self.vel_yr,
                        self.p_xl,
                        self.p_xr,
                        self.p_yl,
                        self.p_yr,
                        self.mass_flux_x,
                        self.mass_flux_y,
                        self.mom_flux_x,
                        self.mom_flux_y,
                        self.e_flux_x,
                        self.e_flux_y,
                        self.gamma,
                        self.dim[1],
                        self.dim[2],
                    ],
                    device=self.device,
                )

                # apply fluxes (advance conserved quantities by one dt)
                wp.launch(
                    euler_apply_flux_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.mass_flux_x,
                        self.mass_flux_y,
                        self.mom_flux_x,
                        self.mom_flux_y,
                        self.e_flux_x,
                        self.e_flux_y,
                        self.mass,
                        self.mom,
                        self.e,
                        self.dx,
                        self.dt,
                        self.dim[1],
                        self.dim[2],
                    ],
                    device=self.device,
                )

    def __iter__(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Yields
        ------
        Infinite iterator that returns a dict with keys ``density``,
        ``velocity`` and ``pressure`` whose values are batched timeseries
        tensors of size [batch, seq_length, dim, resolution, resolution]
        """
        # NOTE(review): each item yielded is a dict, not a tuple — the return
        # annotation above is inaccurate and kept only for compatibility.
        # infinite generator
        while True:
            # run simulation
            self.generate_batch()

            # return all samples generated before rerunning simulation;
            # each batch element gets an independently shuffled ordering of
            # valid sequence start indices
            batch_ind = [
                np.arange(self.nr_snapshots - self.seq_length)
                for _ in range(self.batch_size)
            ]
            for b_ind in batch_ind:
                np.random.shuffle(b_ind)
            for bb in range(self.nr_snapshots - self.seq_length):
                # run over batch to gather samples
                batched_seq_rho = []
                batched_seq_vel = []
                batched_seq_p = []
                for b in range(self.batch_size):
                    # gather seq from each batch
                    seq_rho = []
                    seq_vel = []
                    seq_p = []
                    for s in range(self.seq_length):
                        # get variables (shares device memory with warp arrays)
                        rho = wp.to_torch(self.seq_rho[batch_ind[b][bb] + s])[b]
                        vel = wp.to_torch(self.seq_vel[batch_ind[b][bb] + s])[b]
                        p = wp.to_torch(self.seq_p[batch_ind[b][bb] + s])[b]

                        # add channels
                        rho = torch.unsqueeze(rho, 0)
                        vel = torch.permute(vel, (2, 0, 1))
                        p = torch.unsqueeze(p, 0)

                        # normalize values
                        if self.normaliser is not None:
                            rho = (
                                rho - self.normaliser["density"][0]
                            ) / self.normaliser["density"][1]
                            vel = (
                                vel - self.normaliser["velocity"][0]
                            ) / self.normaliser["velocity"][1]
                            p = (p - self.normaliser["pressure"][0]) / self.normaliser[
                                "pressure"
                            ][1]

                        # store for producing seq
                        seq_rho.append(rho)
                        seq_vel.append(vel)
                        seq_p.append(p)

                    # concat seq
                    batched_seq_rho.append(torch.stack(seq_rho, axis=0))
                    batched_seq_vel.append(torch.stack(seq_vel, axis=0))
                    batched_seq_p.append(torch.stack(seq_p, axis=0))

                # CUDA graphs static copies: keep yielding the same tensor
                # objects, refreshing their contents in place
                if self.output_rho is None:
                    # concat batches
                    self.output_rho = torch.stack(batched_seq_rho, axis=0)
                    self.output_vel = torch.stack(batched_seq_vel, axis=0)
                    self.output_p = torch.stack(batched_seq_p, axis=0)
                else:
                    self.output_rho.data.copy_(torch.stack(batched_seq_rho, axis=0))
                    self.output_vel.data.copy_(torch.stack(batched_seq_vel, axis=0))
                    self.output_p.data.copy_(torch.stack(batched_seq_p, axis=0))

                yield {
                    "density": self.output_rho,
                    "velocity": self.output_vel,
                    "pressure": self.output_p,
                }

    def __len__(self):
        # effectively infinite: samples are generated on the fly forever
        return sys.maxsize
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_difference.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ try:
19
+ import warp as wp
20
+ except ImportError:
21
+ print(
22
+ """NVIDIA WARP is required for this datapipe. This package is under the
23
+ NVIDIA Source Code License (NVSCL). To install use:
24
+
25
+ pip install warp-lang
26
+ """
27
+ )
28
+ raise SystemExit(1)
29
+
30
+ from .indexing import index_clamped_edges_batched_2d, index_zero_edges_batched_2d
31
+
32
+
33
@wp.kernel
def darcy_mgrid_jacobi_iterative_batched_2d(
    darcy0: wp.array3d(dtype=float),
    darcy1: wp.array3d(dtype=float),
    permeability: wp.array3d(dtype=float),
    source: float,
    lx: int,
    ly: int,
    dx: float,
    mgrid_reduction_factor: int,
):  # pragma: no cover
    """Multi-grid Jacobi step for the Darcy equation.

    Performs one Jacobi relaxation sweep on a coarsened view of the fine
    grid: ``mgrid_reduction_factor`` maps each thread onto every n-th fine
    cell, so the same kernel serves all multi-grid levels. Reads from
    ``darcy0`` and writes into ``darcy1``; the caller swaps the two buffers
    between iterations.

    Parameters
    ----------
    darcy0 : wp.array3d
        Darcy solution previous step (read)
    darcy1 : wp.array3d
        Darcy solution for next step (written)
    permeability : wp.array3d
        Permeability field for Darcy equation
    source : float
        Constant source value for Darcy equation
    lx : int
        Length of domain in x dim (fine grid)
    ly : int
        Length of domain in y dim (fine grid)
    dx : float
        Fine-grid cell size
    mgrid_reduction_factor : int
        Coarsening factor of the multi-grid level currently being relaxed
    """

    # get index: batch and coarse-grid cell
    b, x, y = wp.tid()

    # map coarse-grid thread index onto the fine grid and scale the cell size
    gx = mgrid_reduction_factor * x + (mgrid_reduction_factor - 1)
    gy = mgrid_reduction_factor * y + (mgrid_reduction_factor - 1)
    gdx = dx * wp.float32(mgrid_reduction_factor)

    # solution stencil; out-of-domain neighbors read as 0 (Dirichlet edges)
    d_0_1 = index_zero_edges_batched_2d(
        darcy0, b, gx - mgrid_reduction_factor, gy, lx, ly
    )
    d_2_1 = index_zero_edges_batched_2d(
        darcy0, b, gx + mgrid_reduction_factor, gy, lx, ly
    )
    d_1_0 = index_zero_edges_batched_2d(
        darcy0, b, gx, gy - mgrid_reduction_factor, lx, ly
    )
    d_1_2 = index_zero_edges_batched_2d(
        darcy0, b, gx, gy + mgrid_reduction_factor, lx, ly
    )

    # permeability stencil; out-of-domain neighbors clamp to the edge value
    p_1_1 = index_clamped_edges_batched_2d(permeability, b, gx, gy, lx, ly)
    p_0_1 = index_clamped_edges_batched_2d(
        permeability, b, gx - mgrid_reduction_factor, gy, lx, ly
    )
    p_2_1 = index_clamped_edges_batched_2d(
        permeability, b, gx + mgrid_reduction_factor, gy, lx, ly
    )
    p_1_0 = index_clamped_edges_batched_2d(
        permeability, b, gx, gy - mgrid_reduction_factor, lx, ly
    )
    p_1_2 = index_clamped_edges_batched_2d(
        permeability, b, gx, gy + mgrid_reduction_factor, lx, ly
    )

    # stencil terms: t_1 is the permeability-weighted Laplacian sum; t_2/t_3
    # couple the permeability and solution differences in x and y.
    # NOTE(review): t_2 and t_3 scale as 1/(2*gdx) rather than the
    # 1/(4*gdx^2) a product of two central differences would give --
    # confirm the intended discretization of div(k grad u).
    dx_squared = gdx * gdx
    t_1 = p_1_1 * (d_0_1 + d_2_1 + d_1_0 + d_1_2) / dx_squared
    t_2 = ((p_2_1 - p_0_1) * (d_2_1 - d_0_1)) / (2.0 * gdx)
    t_3 = ((p_1_2 - p_1_0) * (d_1_2 - d_1_0)) / (2.0 * gdx)

    # jacobi iterative method: solve the stencil equation for the center value
    d_star = (t_1 + t_2 + t_3 + source) / (p_1_1 * 4.0 / dx_squared)

    # buffers get swapped each iteration
    darcy1[b, gx, gy] = d_star
114
+
115
+
116
@wp.kernel
def mgrid_inf_residual_batched_2d(
    phi0: wp.array3d(dtype=float),
    phi1: wp.array3d(dtype=float),
    inf_res: wp.array(dtype=float),
    mgrid_reduction_factor: int,
):  # pragma: no cover
    """Infinity norm for checking multi-grid solutions.

    Atomically accumulates ``max |phi0 - phi1|`` over the coarse-grid cells
    into ``inf_res[0]``. The caller is expected to reset ``inf_res`` before
    launching; the max is taken across all batches.

    Parameters
    ----------
    phi0 : wp.array3d
        Previous solution
    phi1 : wp.array3d
        Current solution
    inf_res : wp.array
        Single-element array that receives the infinity norm
    mgrid_reduction_factor : int
        Coarsening factor of the multi-grid level being checked
    """
    b, x, y = wp.tid()
    # map coarse-grid thread index onto the fine grid (same mapping as the
    # Jacobi kernel above)
    gx = mgrid_reduction_factor * x + (mgrid_reduction_factor - 1)
    gy = mgrid_reduction_factor * y + (mgrid_reduction_factor - 1)
    wp.atomic_max(inf_res, 0, wp.abs(phi0[b, gx, gy] - phi1[b, gx, gy]))
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_volume.py ADDED
@@ -0,0 +1,759 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ try:
18
+ import warp as wp
19
+ except ImportError:
20
+ print(
21
+ """NVIDIA WARP is required for this datapipe. This package is under the
22
+ NVIDIA Source Code License (NVSCL). To install use:
23
+
24
+ pip install warp-lang
25
+ """
26
+ )
27
+ raise SystemExit(1)
28
+
29
+ from .indexing import (
30
+ index_periodic_edges_batched_2d,
31
+ index_vec2_periodic_edges_batched_2d,
32
+ )
33
+
34
+
35
@wp.func
def extrapolate_to_face_2d(
    f: float, f_dx: float, f_dy: float, dx: float
):  # pragma: no cover
    """Extrapolate a cell-centered value to the four face centers.

    First-order Taylor expansion using the supplied gradients, stepping
    half a cell in each direction (cells are square: the same ``dx`` is
    used for x and y).

    Parameters
    ----------
    f : float
        Cell value
    f_dx : float
        X derivative of cell value
    f_dy : float
        Y derivative of cell value
    dx : float
        Cell size

    Returns
    -------
    wp.vec4
        (value on left x, value on right x, value left y, value right y)
    """
    f_xl = f - f_dx * (dx / 2.0)
    f_xr = f + f_dx * (dx / 2.0)
    f_yl = f - f_dy * (dx / 2.0)
    f_yr = f + f_dy * (dx / 2.0)
    return wp.vec4(f_xl, f_xr, f_yl, f_yr)
62
+
63
+
64
@wp.func
def apply_flux_2d(
    f: float,
    flux_f_xl_dx: float,
    flux_f_xr_dx: float,
    flux_f_yl_dy: float,
    flux_f_yr_dy: float,
    dx: float,
    dt: float,
):  # pragma: no cover
    """Apply face fluxes to a scalar cell value.

    Left/bottom face fluxes are subtracted and right/top face fluxes are
    added, each weighted by ``dt * dx`` (time step times face length).

    Parameters
    ----------
    f : float
        Cell value
    flux_f_xl_dx : float
        Left x flux
    flux_f_xr_dx : float
        Right x flux
    flux_f_yl_dy : float
        Left y flux
    flux_f_yr_dy : float
        Right y flux
    dx : float
        Cell size
    dt : float
        Time step size

    Returns
    -------
    float
        Cell value with added flux
    """
    f += -dt * dx * flux_f_xl_dx
    f += dt * dx * flux_f_xr_dx
    f += -dt * dx * flux_f_yl_dy
    f += dt * dx * flux_f_yr_dy
    return f
103
+
104
+
105
@wp.func
def apply_flux_vec2_2d(
    f: wp.vec2,
    flux_f_xl_dx: wp.vec2,
    flux_f_xr_dx: wp.vec2,
    flux_f_yl_dy: wp.vec2,
    flux_f_yr_dy: wp.vec2,
    dx: float,
    dt: float,
):  # pragma: no cover
    """Apply face fluxes to a vector (wp.vec2) cell value.

    Vector counterpart of ``apply_flux_2d``: left/bottom face fluxes are
    subtracted and right/top face fluxes added, weighted by ``dt * dx``.

    Parameters
    ----------
    f : wp.vec2
        Cell vector value
    flux_f_xl_dx : wp.vec2
        Vector flux in x left
    flux_f_xr_dx : wp.vec2
        Vector flux in x right
    flux_f_yl_dy : wp.vec2
        Vector flux in y left
    flux_f_yr_dy : wp.vec2
        Vector flux in y right
    dx : float
        Cell size
    dt : float
        Time step size

    Returns
    -------
    wp.vec2
        Vector cell value with added flux
    """
    f += -dt * dx * flux_f_xl_dx
    f += dt * dx * flux_f_xr_dx
    f += -dt * dx * flux_f_yl_dy
    f += dt * dx * flux_f_yr_dy
    return f
144
+
145
+
146
@wp.func
def euler_flux_2d(
    rho_l: float,
    rho_r: float,
    vx_l: float,
    vx_r: float,
    vy_l: float,
    vy_r: float,
    p_l: float,
    p_r: float,
    gamma: float,
):  # pragma: no cover
    """Compute the Euler flux across a face from left/right states.

    Evaluates the x-direction flux of the 2D compressible Euler equations
    at the arithmetic average of the two states, then adds a dissipation
    term proportional to the larger acoustic wavespeed of the two sides
    (a Lax-Friedrichs / Rusanov-style stabilization).

    Parameters
    ----------
    rho_l : float
        Density left
    rho_r : float
        Density right
    vx_l : float
        X velocity left
    vx_r : float
        X velocity right
    vy_l : float
        Y velocity left
    vy_r : float
        Y velocity right
    p_l : float
        Pressure left
    p_r : float
        Pressure right
    gamma : float
        Gas constant (ratio of specific heats)

    Returns
    -------
    wp.vec4
        Vector containing mass, momentum x, momentum y, and energy flux.
    """
    # total energy density on each side (internal + kinetic)
    e_l = p_l / (gamma - 1.0) + 0.5 * rho_l * (vx_l * vx_l + vy_l * vy_l)
    e_r = p_r / (gamma - 1.0) + 0.5 * rho_r * (vx_r * vx_r + vy_r * vy_r)

    # averaged states at the face
    rho_ave = 0.5 * (rho_l + rho_r)
    momx_ave = 0.5 * (rho_l * vx_l + rho_r * vx_r)
    momy_ave = 0.5 * (rho_l * vy_l + rho_r * vy_r)
    e_ave = 0.5 * (e_l + e_r)
    p_ave = (gamma - 1.0) * (
        e_ave - 0.5 * (momx_ave * momx_ave + momy_ave * momy_ave) / rho_ave
    )

    # physical fluxes evaluated at the averaged state
    flux_mass = momx_ave
    flux_momx = momx_ave * momx_ave / rho_ave + p_ave
    flux_momy = momx_ave * momy_ave / rho_ave
    flux_e = (e_ave + p_ave) * momx_ave / rho_ave

    # max wavespeed of the two sides: sound speed plus normal velocity
    c_l = wp.sqrt(gamma * p_l / rho_l) + wp.abs(vx_l)
    c_r = wp.sqrt(gamma * p_r / rho_r) + wp.abs(vx_r)
    c = wp.max(c_l, c_r)

    # add stabilizing diffusion term proportional to the state jump
    flux_mass -= c * 0.5 * (rho_l - rho_r)
    flux_momx -= c * 0.5 * (rho_l * vx_l - rho_r * vx_r)
    flux_momy -= c * 0.5 * (rho_l * vy_l - rho_r * vy_r)
    flux_e -= c * 0.5 * (e_l - e_r)

    return wp.vec4(flux_mass, flux_momx, flux_momy, flux_e)
217
+
218
+
219
@wp.kernel
def euler_primitive_to_conserved_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    gamma: float,
    vol: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Convert primitive Euler variables (rho, vel, p) to conserved ones.

    Writes cell-integrated mass, momentum, and total energy (hence the
    multiplication by the cell volume ``vol``).

    Parameters
    ----------
    rho : wp.array3d
        Density (read)
    vel : wp.array3d
        Velocity (read)
    p : wp.array3d
        Pressure (read)
    mass : wp.array3d
        Mass (written)
    mom : wp.array3d
        Momentum (written)
    e : wp.array3d
        Energy (written)
    gamma : float
        Gas constant (ratio of specific heats)
    vol : float
        Volume of cell
    lx : int
        Grid size x dim
    ly : int
        Grid size y dim
    """

    # get index
    b, i, j = wp.tid()

    # read primitive values (periodic indexing; i, j are in-range here)
    rho_i_j = index_periodic_edges_batched_2d(rho, b, i, j, lx, ly)
    vel_i_j = index_vec2_periodic_edges_batched_2d(vel, b, i, j, lx, ly)
    p_i_j = index_periodic_edges_batched_2d(p, b, i, j, lx, ly)

    # compute conserved values (internal + kinetic energy, times volume)
    mass_i_j = rho_i_j * vol
    mom_i_j = vel_i_j * rho_i_j * vol
    e_i_j = (
        p_i_j / (gamma - 1.0)
        + 0.5 * rho_i_j * (vel_i_j[0] * vel_i_j[0] + vel_i_j[1] * vel_i_j[1])
    ) * vol

    # set values
    mass[b, i, j] = mass_i_j
    mom[b, i, j] = mom_i_j
    e[b, i, j] = e_i_j
278
+
279
+
280
@wp.kernel
def euler_conserved_to_primitive_batched_2d(
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    gamma: float,
    vol: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Convert conserved Euler variables (mass, mom, e) to primitive ones.

    Inverse of ``euler_primitive_to_conserved_batched_2d``: divides the
    cell-integrated quantities by the cell volume and removes the kinetic
    contribution from the energy to recover pressure.

    Parameters
    ----------
    mass : wp.array3d
        Mass (read)
    mom : wp.array3d
        Momentum (read)
    e : wp.array3d
        Energy (read)
    rho : wp.array3d
        Density (written)
    vel : wp.array3d
        Velocity (written)
    p : wp.array3d
        Pressure (written)
    gamma : float
        Gas constant (ratio of specific heats)
    vol : float
        Cell volume
    lx : int
        Grid size X dim
    ly : int
        Grid size Y dim
    """

    # get index
    b, i, j = wp.tid()

    # read conserved values (periodic indexing; i, j are in-range here)
    mass_i_j = index_periodic_edges_batched_2d(mass, b, i, j, lx, ly)
    mom_i_j = index_vec2_periodic_edges_batched_2d(mom, b, i, j, lx, ly)
    e_i_j = index_periodic_edges_batched_2d(e, b, i, j, lx, ly)

    # recover primitive values: rho = mass/vol, vel = mom/(rho*vol),
    # p from total minus kinetic energy
    rho_i_j = mass_i_j / vol
    vel_i_j = mom_i_j / rho_i_j / vol
    p_i_j = (
        e_i_j / vol
        - 0.5 * rho_i_j * (vel_i_j[0] * vel_i_j[0] + vel_i_j[1] * vel_i_j[1])
    ) * (gamma - 1.0)

    # set values
    rho[b, i, j] = rho_i_j
    vel[b, i, j] = vel_i_j
    p[b, i, j] = p_i_j
339
+
340
+
341
@wp.kernel
def euler_extrapolation_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    rho_xl: wp.array3d(dtype=float),
    rho_xr: wp.array3d(dtype=float),
    rho_yl: wp.array3d(dtype=float),
    rho_yr: wp.array3d(dtype=float),
    vel_xl: wp.array3d(dtype=wp.vec2),
    vel_xr: wp.array3d(dtype=wp.vec2),
    vel_yl: wp.array3d(dtype=wp.vec2),
    vel_yr: wp.array3d(dtype=wp.vec2),
    p_xl: wp.array3d(dtype=float),
    p_xr: wp.array3d(dtype=float),
    p_yl: wp.array3d(dtype=float),
    p_yr: wp.array3d(dtype=float),
    gamma: float,
    dx: float,
    dt: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Extrapolate Euler primitive values to cell faces.

    Two-stage predictor: first advances rho, vel, and p by half a time step
    using the primitive-variable Euler equations with central-difference
    gradients, then extrapolates the half-step values to the four face
    centers of each cell. All boundaries are periodic.

    Parameters
    ----------
    rho : wp.array3d
        Density
    vel : wp.array3d
        Velocity
    p : wp.array3d
        Pressure
    rho_xl : wp.array3d
        Density x left (written)
    rho_xr : wp.array3d
        Density x right (written)
    rho_yl : wp.array3d
        Density y left (written)
    rho_yr : wp.array3d
        Density y right (written)
    vel_xl : wp.array3d
        Velocity x left (written)
    vel_xr : wp.array3d
        Velocity x right (written)
    vel_yl : wp.array3d
        Velocity y left (written)
    vel_yr : wp.array3d
        Velocity y right (written)
    p_xl : wp.array3d
        Pressure x left (written)
    p_xr : wp.array3d
        Pressure x right (written)
    p_yl : wp.array3d
        Pressure y left (written)
    p_yr : wp.array3d
        Pressure y right (written)
    gamma : float
        Gas constant (ratio of specific heats)
    dx : float
        Cell size
    dt : float
        Time step size
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # density stencil (periodic edges)
    rho_1_1 = index_periodic_edges_batched_2d(rho, b, i, j, lx, ly)
    rho_2_1 = index_periodic_edges_batched_2d(rho, b, i + 1, j, lx, ly)
    rho_1_2 = index_periodic_edges_batched_2d(rho, b, i, j + 1, lx, ly)
    rho_0_1 = index_periodic_edges_batched_2d(rho, b, i - 1, j, lx, ly)
    rho_1_0 = index_periodic_edges_batched_2d(rho, b, i, j - 1, lx, ly)

    # velocity stencil (periodic edges)
    vel_1_1 = index_vec2_periodic_edges_batched_2d(vel, b, i, j, lx, ly)
    vel_2_1 = index_vec2_periodic_edges_batched_2d(vel, b, i + 1, j, lx, ly)
    vel_1_2 = index_vec2_periodic_edges_batched_2d(vel, b, i, j + 1, lx, ly)
    vel_0_1 = index_vec2_periodic_edges_batched_2d(vel, b, i - 1, j, lx, ly)
    vel_1_0 = index_vec2_periodic_edges_batched_2d(vel, b, i, j - 1, lx, ly)

    # pressure stencil (periodic edges)
    p_1_1 = index_periodic_edges_batched_2d(p, b, i, j, lx, ly)
    p_2_1 = index_periodic_edges_batched_2d(p, b, i + 1, j, lx, ly)
    p_1_2 = index_periodic_edges_batched_2d(p, b, i, j + 1, lx, ly)
    p_0_1 = index_periodic_edges_batched_2d(p, b, i - 1, j, lx, ly)
    p_1_0 = index_periodic_edges_batched_2d(p, b, i, j - 1, lx, ly)

    # central-difference density gradient
    rho_dx = (rho_2_1 - rho_0_1) / (2.0 * dx)
    rho_dy = (rho_1_2 - rho_1_0) / (2.0 * dx)

    # central-difference velocity gradient
    vel_dx = (vel_2_1 - vel_0_1) / (2.0 * dx)
    vel_dy = (vel_1_2 - vel_1_0) / (2.0 * dx)

    # central-difference pressure gradient
    p_dx = (p_2_1 - p_0_1) / (2.0 * dx)
    p_dy = (p_1_2 - p_1_0) / (2.0 * dx)

    # half-time-step predictor: continuity equation for density ...
    rho_prime = rho_1_1 - 0.5 * dt * (
        vel_1_1[0] * rho_dx
        + rho_1_1 * vel_dx[0]
        + vel_1_1[1] * rho_dy
        + rho_1_1 * vel_dy[1]
    )
    # ... momentum equations for velocity ...
    vx_prime = vel_1_1[0] - 0.5 * dt * (
        vel_1_1[0] * vel_dx[0] + vel_1_1[1] * vel_dy[0] + (1.0 / rho_1_1) * p_dx
    )
    vy_prime = vel_1_1[1] - 0.5 * dt * (
        vel_1_1[0] * vel_dx[1] + vel_1_1[1] * vel_dy[1] + (1.0 / rho_1_1) * p_dy
    )
    # ... and the pressure equation
    p_prime = p_1_1 - 0.5 * dt * (
        gamma * p_1_1 * (vel_dx[0] + vel_dy[1]) + vel_1_1[0] * p_dx + vel_1_1[1] * p_dy
    )

    # extrapolate the half-step values in space to the face centers
    rho_space_extra = extrapolate_to_face_2d(rho_prime, rho_dx, rho_dy, dx)
    vx_space_extra = extrapolate_to_face_2d(vx_prime, vel_dx[0], vel_dy[0], dx)
    vy_space_extra = extrapolate_to_face_2d(vy_prime, vel_dx[1], vel_dy[1], dx)
    p_space_extra = extrapolate_to_face_2d(p_prime, p_dx, p_dy, dx)

    # store values (vec4 layout: [xl, xr, yl, yr])
    rho_xl[b, i, j] = rho_space_extra[0]
    rho_xr[b, i, j] = rho_space_extra[1]
    rho_yl[b, i, j] = rho_space_extra[2]
    rho_yr[b, i, j] = rho_space_extra[3]
    vel_xl[b, i, j] = wp.vec2(vx_space_extra[0], vy_space_extra[0])
    vel_xr[b, i, j] = wp.vec2(vx_space_extra[1], vy_space_extra[1])
    vel_yl[b, i, j] = wp.vec2(vx_space_extra[2], vy_space_extra[2])
    vel_yr[b, i, j] = wp.vec2(vx_space_extra[3], vy_space_extra[3])
    p_xl[b, i, j] = p_space_extra[0]
    p_xr[b, i, j] = p_space_extra[1]
    p_yl[b, i, j] = p_space_extra[2]
    p_yr[b, i, j] = p_space_extra[3]
482
+
483
+
484
@wp.kernel
def euler_get_flux_batched_2d(
    rho_xl: wp.array3d(dtype=float),
    rho_xr: wp.array3d(dtype=float),
    rho_yl: wp.array3d(dtype=float),
    rho_yr: wp.array3d(dtype=float),
    vel_xl: wp.array3d(dtype=wp.vec2),
    vel_xr: wp.array3d(dtype=wp.vec2),
    vel_yl: wp.array3d(dtype=wp.vec2),
    vel_yr: wp.array3d(dtype=wp.vec2),
    p_xl: wp.array3d(dtype=float),
    p_xr: wp.array3d(dtype=float),
    p_yl: wp.array3d(dtype=float),
    p_yr: wp.array3d(dtype=float),
    mass_flux_x: wp.array3d(dtype=float),
    mass_flux_y: wp.array3d(dtype=float),
    mom_flux_x: wp.array3d(dtype=wp.vec2),
    mom_flux_y: wp.array3d(dtype=wp.vec2),
    e_flux_x: wp.array3d(dtype=float),
    e_flux_y: wp.array3d(dtype=float),
    gamma: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Use extrapolated Euler face values to compute fluxes.

    For each cell, pairs its right-face value with the left-face value of
    the next cell (periodic wrap) and evaluates ``euler_flux_2d`` across
    the shared face in both x and y directions.

    Parameters
    ----------
    rho_xl : wp.array3d
        Density x left
    rho_xr : wp.array3d
        Density x right
    rho_yl : wp.array3d
        Density y left
    rho_yr : wp.array3d
        Density y right
    vel_xl : wp.array3d
        Velocity x left
    vel_xr : wp.array3d
        Velocity x right
    vel_yl : wp.array3d
        Velocity y left
    vel_yr : wp.array3d
        Velocity y right
    p_xl : wp.array3d
        Pressure x left
    p_xr : wp.array3d
        Pressure x right
    p_yl : wp.array3d
        Pressure y left
    p_yr : wp.array3d
        Pressure y right
    mass_flux_x : wp.array3d
        Mass flux x (written)
    mass_flux_y : wp.array3d
        Mass flux y (written)
    mom_flux_x : wp.array3d
        Momentum flux x (written)
    mom_flux_y : wp.array3d
        Momentum flux y (written)
    e_flux_x : wp.array3d
        Energy flux x (written)
    e_flux_y : wp.array3d
        Energy flux y (written)
    gamma : float
        Gas constant (ratio of specific heats)
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # face states: neighbor's left face vs. this cell's right face
    rho_xl_1 = index_periodic_edges_batched_2d(rho_xl, b, i + 1, j, lx, ly)
    rho_xr_0 = index_periodic_edges_batched_2d(rho_xr, b, i, j, lx, ly)
    rho_yl_1 = index_periodic_edges_batched_2d(rho_yl, b, i, j + 1, lx, ly)
    rho_yr_0 = index_periodic_edges_batched_2d(rho_yr, b, i, j, lx, ly)
    vel_xl_1 = index_vec2_periodic_edges_batched_2d(vel_xl, b, i + 1, j, lx, ly)
    vel_xr_0 = index_vec2_periodic_edges_batched_2d(vel_xr, b, i, j, lx, ly)
    vel_yl_1 = index_vec2_periodic_edges_batched_2d(vel_yl, b, i, j + 1, lx, ly)
    vel_yr_0 = index_vec2_periodic_edges_batched_2d(vel_yr, b, i, j, lx, ly)
    p_xl_1 = index_periodic_edges_batched_2d(p_xl, b, i + 1, j, lx, ly)
    p_xr_0 = index_periodic_edges_batched_2d(p_xr, b, i, j, lx, ly)
    p_yl_1 = index_periodic_edges_batched_2d(p_yl, b, i, j + 1, lx, ly)
    p_yr_0 = index_periodic_edges_batched_2d(p_yr, b, i, j, lx, ly)

    # x-direction flux
    flux_x = euler_flux_2d(
        rho_xl_1,
        rho_xr_0,
        vel_xl_1[0],
        vel_xr_0[0],
        vel_xl_1[1],
        vel_xr_0[1],
        p_xl_1,
        p_xr_0,
        gamma,
    )
    # y-direction flux: euler_flux_2d is written for the x direction, so the
    # velocity components are swapped on input ...
    flux_y = euler_flux_2d(
        rho_yl_1,
        rho_yr_0,
        vel_yl_1[1],
        vel_yr_0[1],
        vel_yl_1[0],
        vel_yr_0[0],
        p_yl_1,
        p_yr_0,
        gamma,
    )

    # ... and swapped back when storing the y momentum flux below
    mass_flux_x[b, i, j] = flux_x[0]
    mass_flux_y[b, i, j] = flux_y[0]
    mom_flux_x[b, i, j] = wp.vec2(flux_x[1], flux_x[2])
    mom_flux_y[b, i, j] = wp.vec2(flux_y[2], flux_y[1])
    e_flux_x[b, i, j] = flux_x[3]
    e_flux_y[b, i, j] = flux_y[3]
604
+
605
+
606
@wp.kernel
def euler_apply_flux_batched_2d(
    mass_flux_x: wp.array3d(dtype=float),
    mass_flux_y: wp.array3d(dtype=float),
    mom_flux_x: wp.array3d(dtype=wp.vec2),
    mom_flux_y: wp.array3d(dtype=wp.vec2),
    e_flux_x: wp.array3d(dtype=float),
    e_flux_y: wp.array3d(dtype=float),
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    dx: float,
    dt: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Apply face fluxes to the conserved Euler values in place.

    Each cell combines its own flux (right/top face) with the flux of the
    previous cell (left/bottom face, periodic wrap) and updates mass,
    momentum, and energy via ``apply_flux_2d`` / ``apply_flux_vec2_2d``.

    Parameters
    ----------
    mass_flux_x : wp.array3d
        Mass flux X
    mass_flux_y : wp.array3d
        Mass flux Y
    mom_flux_x : wp.array3d
        Momentum flux X
    mom_flux_y : wp.array3d
        Momentum flux Y
    e_flux_x : wp.array3d
        Energy flux X
    e_flux_y : wp.array3d
        Energy flux Y
    mass : wp.array3d
        Mass (updated in place)
    mom : wp.array3d
        Momentum (updated in place)
    e : wp.array3d
        Energy (updated in place)
    dx : float
        Cell size
    dt : float
        Time step size
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # update mass: own face fluxes plus the neighbors' shared faces
    mass_1 = index_periodic_edges_batched_2d(mass, b, i, j, lx, ly)
    mass_flux_x_1 = index_periodic_edges_batched_2d(mass_flux_x, b, i, j, lx, ly)
    mass_flux_x_0 = index_periodic_edges_batched_2d(mass_flux_x, b, i - 1, j, lx, ly)
    mass_flux_y_1 = index_periodic_edges_batched_2d(mass_flux_y, b, i, j, lx, ly)
    mass_flux_y_0 = index_periodic_edges_batched_2d(mass_flux_y, b, i, j - 1, lx, ly)
    new_mass = apply_flux_2d(
        mass_1, mass_flux_x_1, mass_flux_x_0, mass_flux_y_1, mass_flux_y_0, dx, dt
    )

    # update momentum (vector-valued)
    mom_1 = index_vec2_periodic_edges_batched_2d(mom, b, i, j, lx, ly)
    mom_flux_x_1 = index_vec2_periodic_edges_batched_2d(mom_flux_x, b, i, j, lx, ly)
    mom_flux_x_0 = index_vec2_periodic_edges_batched_2d(mom_flux_x, b, i - 1, j, lx, ly)
    mom_flux_y_1 = index_vec2_periodic_edges_batched_2d(mom_flux_y, b, i, j, lx, ly)
    mom_flux_y_0 = index_vec2_periodic_edges_batched_2d(mom_flux_y, b, i, j - 1, lx, ly)
    new_mom = apply_flux_vec2_2d(
        mom_1, mom_flux_x_1, mom_flux_x_0, mom_flux_y_1, mom_flux_y_0, dx, dt
    )

    # update energy
    e_1 = index_periodic_edges_batched_2d(e, b, i, j, lx, ly)
    e_flux_x_1 = index_periodic_edges_batched_2d(e_flux_x, b, i, j, lx, ly)
    e_flux_x_0 = index_periodic_edges_batched_2d(e_flux_x, b, i - 1, j, lx, ly)
    e_flux_y_1 = index_periodic_edges_batched_2d(e_flux_y, b, i, j, lx, ly)
    e_flux_y_0 = index_periodic_edges_batched_2d(e_flux_y, b, i, j - 1, lx, ly)
    new_e = apply_flux_2d(e_1, e_flux_x_1, e_flux_x_0, e_flux_y_1, e_flux_y_0, dx, dt)

    # set values
    mass[b, i, j] = new_mass
    mom[b, i, j] = new_mom
    e[b, i, j] = new_e
689
+
690
+
691
@wp.kernel
def initialize_kelvin_helmoltz_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    w: wp.array2d(dtype=float),
    sigma: float,
    lx: float,
    ly: float,
    nr_freq: int,
):  # pragma: no cover
    """Initialize state for the Kelvin-Helmholtz instability.

    Sets up two counter-flowing horizontal bands of different density with
    a multi-frequency sinusoidal y-velocity perturbation concentrated at
    the two shear layers (y = 0.25 and y = 0.75), and uniform pressure.
    (The ``helmoltz`` spelling in the kernel name is kept for backward
    compatibility.)

    Parameters
    ----------
    rho : wp.array3d
        Density (written)
    vel : wp.array3d
        Velocity (written)
    p : wp.array3d
        Pressure (written)
    w : wp.array2d
        Per-batch perturbation amplitude for each frequency
    sigma : float
        Width of the Gaussian envelope around each shear layer
    lx : float
        Grid size x (used to normalize coordinates to [0, 1))
    ly : float
        Grid size y (used to normalize coordinates to [0, 1))
    nr_freq : int
        Number of frequencies in perturbation
    """

    # normalized cell coords in [0, 1)
    b, i, j = wp.tid()
    x = wp.float(i) / wp.float(lx)
    y = wp.float(j) / wp.float(ly)

    # initial flow bands: dense band moving right, light bands moving left
    if wp.abs(y - 0.5) < 0.25:
        ux = 0.5
        r = 2.0
    else:
        ux = -0.5
        r = 1.0

    # y-velocity perturbation: sum of sine modes, Gaussian-localized at the
    # two shear layers (3.14159 is a low-precision pi literal)
    uy = wp.float32(0.0)
    for f in range(nr_freq):
        ff = wp.float32(f + 1)
        uy += (
            ff
            * w[b, f]
            * wp.sin(4.0 * 3.14159 * x * ff)
            * (
                wp.exp(-(y - 0.25) * (y - 0.25) / (2.0 * sigma * sigma))
                + wp.exp(-(y - 0.75) * (y - 0.75) / (2.0 * sigma * sigma))
            )
        )
    u = wp.vec2(ux, uy)

    # set values (uniform pressure)
    rho[b, i, j] = r
    vel[b, i, j] = u
    p[b, i, j] = 2.5
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/indexing.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ try:
18
+ import warp as wp
19
+ except ImportError:
20
+ print(
21
+ """NVIDIA WARP is required for this datapipe. This package is under the
22
+ NVIDIA Source Code License (NVSCL). To install use:
23
+
24
+ pip install warp-lang
25
+ """
26
+ )
27
+ raise SystemExit(1)
28
+
29
+
30
# TODO bug in warp mod function
@wp.func
def _mod_int(x: int, length: int):  # pragma: no cover
    """Wrap an index into ``[0, length)`` by a single period.

    Hand-rolled replacement for ``%`` (see TODO above). Only corrects one
    period in either direction: assumes ``-length <= x <= 2 * length - 1``.

    Parameters
    ----------
    x : int
        Int to mod
    length : int
        Mod by value

    Returns
    -------
    int
        Mod of x
    """
    if x < 0:
        return x + length
    elif x > length - 1:
        return x - length
    return x
52
+
53
+
54
@wp.func
def index_zero_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index a batched 2d array with zero on the edges.

    Returns 0.0 only for indices exactly one cell outside the domain
    (x == -1, x == lx, y == -1, or y == ly); in-range indices read through.

    Parameters
    ----------
    array : wp.array3d
        Array to index
    b : int
        Batch index
    x : int
        X index
    y : int
        Y index
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value, or 0.0 just outside the boundary
    """
    # NOTE(review): indices more than one cell outside (e.g. stencil steps of
    # mgrid_reduction_factor > 1 at the right/top edge) fall through to the
    # raw array access unchecked -- confirm launch dims keep them in range.
    if x == -1:
        return 0.0
    elif x == lx:
        return 0.0
    elif y == -1:
        return 0.0
    elif y == ly:
        return 0.0
    else:
        return array[b, x, y]
90
+
91
+
92
@wp.func
def index_clamped_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index a batched 2d array with out-of-range indices clamped to the edge.

    Any x or y outside ``[0, lx)`` / ``[0, ly)`` reads the nearest boundary
    cell (zero-gradient extension).

    Parameters
    ----------
    array : wp.array3d
        Array to index
    b : int
        Batch index
    x : int
        X index
    y : int
        Y index
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value at the clamped index
    """
    x = wp.clamp(x, 0, lx - 1)
    y = wp.clamp(y, 0, ly - 1)
    return array[b, x, y]
121
+
122
+
123
@wp.func
def index_periodic_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index a batched 2d array with periodic edges.

    Indices are wrapped with ``_mod_int``, which corrects at most one
    period in either direction (sufficient for the +/-1 stencils used by
    the finite-volume kernels).

    Parameters
    ----------
    array : wp.array3d
        Array to index
    b : int
        Batch index
    x : int
        X index
    y : int
        Y index
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value at the wrapped index
    """
    x = _mod_int(x, lx)
    y = _mod_int(y, ly)
    return array[b, x, y]
152
+
153
+
154
@wp.func
def index_vec2_periodic_edges_batched_2d(
    vec: wp.array3d(dtype=wp.vec2), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index a batched 2d array of wp.vec2 with periodic edges.

    Vector counterpart of ``index_periodic_edges_batched_2d``; indices are
    wrapped with ``_mod_int`` (single-period correction only).

    Parameters
    ----------
    vec : wp.array3d
        Array to index
    b : int
        Batch index
    x : int
        X index
    y : int
        Y index
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    wp.vec2
        Vector value at the wrapped index
    """
    x = _mod_int(x, lx)
    y = _mod_int(y, ly)
    return vec[b, x, y]
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/initialization.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ try:
18
+ import warp as wp
19
+ except ImportError:
20
+ print(
21
+ """NVIDIA WARP is required for this datapipe. This package is under the
22
+ NVIDIA Source Code License (NVSCL). To install use:
23
+
24
+ pip install warp-lang
25
+ """
26
+ )
27
+ raise SystemExit(1)
28
+
29
+
30
@wp.kernel
def init_uniform_random_2d(
    array: wp.array2d(dtype=float),
    min_value: float,
    max_value: float,
    external_seed: int,
): # pragma: no cover
    """Initialize 2d array with uniform random values

    Fills ``array`` in place with one independent uniform draw from
    [min_value, max_value) per element.

    Parameters
    ----------
    array : wp.array2d
        Array to initialize (modified in place)
    min_value : float
        Min random value
    max_value : float
        Max random value
    external_seed : int
        External seed to use
    """
    i, j = wp.tid()
    # NOTE(review): wp.tid() is reused here as the per-thread RNG offset;
    # assumed to provide a unique linear index per thread (matches the 4d
    # kernel in this file) — confirm against the warp docs.
    state = wp.rand_init(external_seed, wp.tid())
    # Fix: sample from [min_value, max_value). The previous code negated the
    # lower bound (-min_value), contradicting the docstring and the 4d kernel.
    array[i, j] = wp.randf(state, min_value, max_value)
53
+
54
+
55
@wp.kernel
def init_uniform_random_4d(
    array: wp.array4d(dtype=float),
    min_value: float,
    max_value: float,
    external_seed: int,
): # pragma: no cover
    """Initialize 4d array with uniform random values

    Fills ``array`` in place with one independent uniform draw from
    [min_value, max_value) per element.

    Parameters
    ----------
    array : wp.array4d
        Array to initialize (modified in place)
    min_value : float
        Min random value
    max_value : float
        Max random value
    external_seed : int
        External seed to use
    """
    b, i, j, k = wp.tid()
    # NOTE(review): wp.tid() is reused here as the per-thread RNG offset;
    # assumed to provide a unique linear index per thread — confirm against
    # the warp docs.
    state = wp.rand_init(external_seed, wp.tid())
    array[b, i, j, k] = wp.randf(state, min_value, max_value)
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/utils.py ADDED
@@ -0,0 +1,141 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ try:
18
+ import warp as wp
19
+ except ImportError:
20
+ print(
21
+ """NVIDIA WARP is required for this datapipe. This package is under the
22
+ NVIDIA Source Code License (NVSCL). To install use:
23
+
24
+ pip install warp-lang
25
+ """
26
+ )
27
+ raise SystemExit(1)
28
+
29
+ from .indexing import index_zero_edges_batched_2d
30
+
31
+
32
@wp.kernel
def bilinear_upsample_batched_2d(
    array: wp.array3d(dtype=float), lx: int, ly: int, grid_reduction_factor: int
): # pragma: no cover
    """Bilinear upsampling from batch 2d array

    In-place bilinear interpolation for a multi-grid level: each thread
    rebuilds its cell value from the four surrounding coarse-grid samples.

    NOTE(review): the neighbor formula ``x - (x + 1) % grid_reduction_factor``
    implies coarse samples sit at indices where
    ``(i + 1) % grid_reduction_factor == 0`` — confirm against the multi-grid
    layout used by the caller. Out-of-range neighbors contribute zero via
    ``index_zero_edges_batched_2d``.

    Parameters
    ----------
    array : wp.array3d
        Array to perform upsampling on (modified in place)
    lx : int
        Grid size X
    ly : int
        Grid size Y
    grid_reduction_factor : int
        Grid reduction factor for multi-grid
    """
    # get index
    b, x, y = wp.tid()

    # get four neighbors coordinates: nearest coarse samples below/at (x_0,
    # y_0) and above/at (x_1, y_1) the current cell
    x_0 = x - (x + 1) % grid_reduction_factor
    x_1 = x + (x + 1) % grid_reduction_factor
    y_0 = y - (y + 1) % grid_reduction_factor
    y_1 = y + (y + 1) % grid_reduction_factor

    # simple linear upsampling; out-of-bounds neighbors read as zero
    d_0_0 = index_zero_edges_batched_2d(array, b, x_0, y_0, lx, ly)
    d_1_0 = index_zero_edges_batched_2d(array, b, x_1, y_0, lx, ly)
    d_0_1 = index_zero_edges_batched_2d(array, b, x_0, y_1, lx, ly)
    d_1_1 = index_zero_edges_batched_2d(array, b, x_1, y_1, lx, ly)

    # get relative distance from the lower coarse sample, in [0, 1)
    rel_x = wp.float32(x - x_0) / wp.float32(grid_reduction_factor)
    rel_y = wp.float32(y - y_0) / wp.float32(grid_reduction_factor)

    # interpolation in x direction
    d_x_0 = (1.0 - rel_x) * d_0_0 + rel_x * d_1_0
    d_x_1 = (1.0 - rel_x) * d_0_1 + rel_x * d_1_1

    # interpolation in y direction
    d = (1.0 - rel_y) * d_x_0 + rel_y * d_x_1

    # set interpolation
    array[b, x, y] = d
77
+
78
+
79
@wp.kernel
def threshold_3d(
    array: wp.array3d(dtype=float), threshold: float, min_value: float, max_value: float
): # pragma: no cover
    """Threshold 3d array by value. Values below the threshold are set to
    `min_value` and those at or above it are set to `max_value`.

    Parameters
    ----------
    array : wp.array3d
        Array to apply threshold on (modified in place)
    threshold : float
        Threshold value
    min_value : float
        Value to set if below threshold
    max_value : float
        Value to set if at or above threshold
    """
    i, j, k = wp.tid()
    # Binarize in place; the comparison is strict, so values exactly equal to
    # the threshold map to max_value.
    if array[i, j, k] < threshold:
        array[i, j, k] = min_value
    else:
        array[i, j, k] = max_value
101
+
102
+
103
@wp.kernel
def fourier_to_array_batched_2d(
    array: wp.array3d(dtype=float),
    fourier: wp.array4d(dtype=float),
    nr_freq: int,
    lx: int,
    ly: int,
): # pragma: no cover
    """Array of Fourier amplitudes to batched 2d spatial array

    Evaluates a real 2D Fourier series at every grid point and accumulates
    the result into ``array`` via atomic adds. ``fourier`` is indexed as
    ``[term, batch, i, j]`` where term 0..3 holds the sin*sin, cos*sin,
    sin*cos and cos*cos amplitudes respectively. The sum is normalized by
    ``nr_freq ** 2``.

    Parameters
    ----------
    array : wp.array3d
        Spatial array (accumulated into; callers presumably zero it first —
        TODO confirm)
    fourier : wp.array4d
        Array of Fourier amplitudes
    nr_freq : int
        Number of frequencies in Fourier array
    lx : int
        Grid size x
    ly : int
        Grid size y
    """
    b, x, y = wp.tid()
    # 6.28318 is 2*pi (truncated); dx/dy are the grid spacings in phase units
    # so the domain spans one full period in each direction.
    dx = 6.28318 / wp.float32(lx)
    dy = 6.28318 / wp.float32(ly)
    rx = dx * wp.float32(x)
    ry = dy * wp.float32(y)
    for i in range(nr_freq):
        for j in range(nr_freq):
            ri = wp.float32(i)
            rj = wp.float32(j)
            # Four quadrature terms for frequency pair (i, j).
            ss = fourier[0, b, i, j] * wp.sin(ri * rx) * wp.sin(rj * ry)
            cs = fourier[1, b, i, j] * wp.cos(ri * rx) * wp.sin(rj * ry)
            sc = fourier[2, b, i, j] * wp.sin(ri * rx) * wp.cos(rj * ry)
            cc = fourier[3, b, i, j] * wp.cos(ri * rx) * wp.cos(rj * ry)
            # Atomic add: accumulation order is nondeterministic, so float
            # results may vary slightly between runs.
            wp.atomic_add(
                array, b, x, y, 1.0 / (wp.float32(nr_freq) ** 2.0) * (ss + cs + sc + cc)
            )
physics_mcp/source/physicsnemo/datapipes/cae/__init__.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .domino_datapipe import DoMINODataPipe
18
+ from .mesh_datapipe import MeshDatapipe
physics_mcp/source/physicsnemo/datapipes/cae/cae_dataset.py ADDED
@@ -0,0 +1,1275 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import pathlib
18
+ import time
19
+ from abc import ABC, abstractmethod
20
+ from concurrent.futures import ThreadPoolExecutor
21
+
22
+ import numpy as np
23
+ import torch
24
+ import torch.distributed as dist
25
+ import zarr
26
+ from torch.distributed.tensor import Replicate, Shard
27
+
28
+ try:
29
+ import tensorstore as ts
30
+
31
+ TENSORSTORE_AVAILABLE = True
32
+ except ImportError:
33
+ TENSORSTORE_AVAILABLE = False
34
+
35
+ try:
36
+ import pyvista as pv
37
+
38
+ PV_AVAILABLE = True
39
+ except ImportError:
40
+ PV_AVAILABLE = False
41
+
42
+ from physicsnemo.distributed import ShardTensor, ShardTensorSpec
43
+ from physicsnemo.distributed.utils import compute_split_shapes
44
+
45
+ # Abstractions:
46
+ # - want to read npy/npz/.zarr/.stl/.vtp files
47
+ # - Need to share next level abstractions
48
+ # - Domain parallel dataloading is supported: output will be ShardTensor instead.
49
+ # - need to be able to configure preprocessing
50
# - CPU -> GPU transfer happens here; it needs to be isolated in its own stream
51
+ # - Output of dataloader should be torch.Tensor objects.
52
+
53
+
54
+ """
55
+ This datapipe handles reading files from Zarr and piping into torch.Tensor objects.
56
+
57
+ It's expected that the files are organized as groups, with each .zarr
58
+ file representing one training example. To improve IO performance, the files
59
+ should be chunked for each array. The reader takes a list of keys in the
60
+ group to read, and will not read keys that are not specified. The exception
61
+ is if _no_ keys are passed, in which case _all_ keys will be read.
62
+ """
63
+
64
+
65
+ class BackendReader(ABC):
66
+ """
67
+ Abstract base class for backend readers.
68
+ """
69
+
70
+ def __init__(
71
+ self,
72
+ keys_to_read: list[str] | None,
73
+ keys_to_read_if_available: dict[str, torch.Tensor] | None,
74
+ ) -> None:
75
+ """
76
+ Initialize the backend reader.
77
+ """
78
+ self.keys_to_read = keys_to_read
79
+ self.keys_to_read_if_available = keys_to_read_if_available
80
+
81
+ self.volume_sampling_size = None
82
+
83
+ self.is_volumetric = any(["volume" in key for key in self.keys_to_read])
84
+
85
+ @abstractmethod
86
+ def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
87
+ """
88
+ Read a file and return a dictionary of tensors.
89
+ """
90
+ pass
91
+
92
+ @abstractmethod
93
+ def read_file_sharded(
94
+ self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
95
+ ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
96
+ """
97
+ Read a file and return a dictionary of tensors ready to convert to ShardTensors.
98
+
99
+ NOTE: this function does not actually convert torch tensors to ShardTensors.
100
+ It's possible that the conversion, in some cases, can be a collective function.
101
+ Due to the async nature of the loader, we don't rely on any ordering of
102
+ collectives and defer them to the last possible minute.
103
+
104
+ Additionally, these functions return CPU tensors and we don't actually
105
+ define shard tensors on cpu.
106
+
107
+ So, the dataset itself will convert a local tensor + shard info to shard tensor
108
+ after the cpu-> gpu movement.
109
+ """
110
+ pass
111
+
112
+ def fill_optional_keys(
113
+ self, data: dict[str, torch.Tensor]
114
+ ) -> dict[str, torch.Tensor]:
115
+ """
116
+ Fill missing keys with the keys from the keys_to_read_if_available dictionary.
117
+ """
118
+ for key in self.keys_to_read_if_available:
119
+ if key not in data.keys():
120
+ data[key] = self.keys_to_read_if_available[key]
121
+ return data
122
+
123
+ def _get_slice_boundaries(
124
+ self, array_shape: tuple[int], this_rank: int, n_splits: int, split_dim: int = 0
125
+ ) -> tuple[int, int, tuple | None]:
126
+ """
127
+ For an array, determine the slice boundaries for parallel reading.
128
+
129
+ Args:
130
+ array_shape: The total shape of the target array.
131
+ this_rank: The rank of the distributed process.
132
+ n_splits: The size of the distributed process.
133
+ split_dim: The dimension to split, default is 0.
134
+
135
+ Returns:
136
+ The slice boundaries for parallel reading.
137
+ """
138
+ # Determine what slice this rank should read
139
+
140
+ sections = compute_split_shapes(array_shape[split_dim], n_splits)
141
+
142
+ global_chunk_start = sum(sections[:this_rank])
143
+ global_chunk_stop = global_chunk_start + sections[this_rank]
144
+
145
+ chunk_sizes = tuple(
146
+ array_shape[:split_dim] + (section,) + array_shape[split_dim + 1 :]
147
+ for section in sections
148
+ )
149
+
150
+ return global_chunk_start, global_chunk_stop, chunk_sizes
151
+
152
+ def set_volume_sampling_size(self, volume_sampling_size: int):
153
+ """
154
+ Set the volume sampling size. When set, the readers will
155
+ assume the volumetric data is shuffled on disk and read only
156
+ contiguous chunks of the data up to the sampling size.
157
+
158
+
159
+ Args:
160
+ volume_sampling_size: The total size of the volume sampling.
161
+
162
+ """
163
+ self.volume_sampling_size = volume_sampling_size
164
+
165
+ def select_random_sections_from_slice(
166
+ self,
167
+ slice_start: int,
168
+ slice_stop: int,
169
+ n_points: int,
170
+ ) -> slice:
171
+ """
172
+
173
+ select the contiguous chunks of the volume data to read.
174
+
175
+ Args:
176
+ n_volume_points: The number of points to sample from the volume.
177
+
178
+ Returns:
179
+ A tuple of the start and stop indices of the contiguous chunks.
180
+ """
181
+
182
+ if slice_stop - slice_start < n_points:
183
+ raise ValueError(
184
+ f"Slice size {slice_stop - slice_start} is less than the number of points {n_points}"
185
+ )
186
+
187
+ # Choose a random start point that will fit the entire n_points region:
188
+ start = np.random.randint(slice_start, slice_stop - n_points)
189
+ return slice(start, start + n_points)
190
+
191
+
192
class NpyFileReader(BackendReader):
    """
    Reader for .npy files, each storing a pickled dict of numpy arrays.
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read one .npy file into a dict of torch tensors.

        Raises
        ------
        ValueError
            If any required key is missing from the file.
        """
        # allow_pickle is required because the file stores a dict; only use
        # this reader on trusted data.
        data = np.load(filename, allow_pickle=True).item()

        missing_keys = set(self.keys_to_read) - set(data.keys())
        if missing_keys:
            # Name the offending file (previously a hardcoded "(unknown)"
            # placeholder).
            raise ValueError(f"Keys {missing_keys} not found in file {filename}")

        data = {key: torch.from_numpy(data[key]) for key in self.keys_to_read}

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: "torch.distributed.DeviceMesh"
    ) -> "dict[str, ShardTensor]":
        # Sharded reading is not implemented for npy files.
        pass

    def set_volume_sampling_size(self, volume_sampling_size: int):
        """
        This is not supported for npy files.
        """
        raise NotImplementedError(
            "volume sampling directly from disk is not supported for npy files."
        )
231
+
232
+
233
class NpzFileReader(BackendReader):
    """
    Reader for .npz archives of numpy arrays.
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read one .npz file into a dict of torch tensors.

        Volumetric keys are sliced with a single shared slice so all
        volumetric arrays stay row-aligned.

        Raises
        ------
        ValueError
            If any required key is missing from the file.
        """
        in_data = np.load(filename)

        keys_missing = set(self.keys_to_read) - set(in_data.keys())
        if keys_missing:
            # Name the offending file (previously a hardcoded "(unknown)"
            # placeholder).
            raise ValueError(f"Keys {keys_missing} not found in file {filename}")

        # Select the slice once, outside the loop, so every volumetric key
        # gets the same rows.
        # NOTE(review): assumes "volume_mesh_centers" is among the keys
        # whenever any key contains "volume" — confirm with callers.
        if self.is_volumetric:
            n_volume = in_data["volume_mesh_centers"].shape[0]
            if self.volume_sampling_size is not None:
                volume_slice = self.select_random_sections_from_slice(
                    0, n_volume, self.volume_sampling_size
                )
            else:
                volume_slice = slice(0, n_volume)

        # This is a slower basic way to do this, to be improved:
        data = {}
        for key in self.keys_to_read:
            if "volume" not in key:
                data[key] = torch.from_numpy(in_data[key][:])
            else:
                data[key] = torch.from_numpy(in_data[key][volume_slice])

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: "torch.distributed.DeviceMesh"
    ) -> "dict[str, ShardTensor]":
        # Sharded reading is not implemented for npz files.
        pass

    def set_volume_sampling_size(self, volume_sampling_size: int):
        """
        This is not supported for npz files.
        """
        raise NotImplementedError(
            "volume sampling directly from disk is not supported for npz files."
        )
291
+
292
+
293
class ZarrFileReader(BackendReader):
    """
    Reader for zarr groups, one group per training example.
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read one zarr group into a dict of torch tensors.

        Volumetric keys are sliced with a single shared slice so all
        volumetric arrays stay row-aligned.

        Raises
        ------
        ValueError
            If any required key is missing from the group.
        """
        group = zarr.open_group(filename, mode="r")

        missing_keys = set(self.keys_to_read) - set(group.keys())
        if missing_keys:
            # Name the offending file (previously a hardcoded "(unknown)"
            # placeholder).
            raise ValueError(f"Keys {missing_keys} not found in file {filename}")

        # Select the slice once, outside the loop, so every volumetric key
        # gets the same rows.
        # NOTE(review): assumes "volume_mesh_centers" is among the keys
        # whenever any key contains "volume" — confirm with callers.
        if self.is_volumetric:
            n_volume = group["volume_mesh_centers"].shape[0]
            if self.volume_sampling_size is not None:
                volume_slice = self.select_random_sections_from_slice(
                    0, n_volume, self.volume_sampling_size
                )
            else:
                volume_slice = slice(0, n_volume)

        # This is a slower basic way to do this, to be improved:
        data = {}
        for key in self.keys_to_read:
            if "volume" not in key:
                data[key] = torch.from_numpy(group[key][:])
            else:
                data[key] = torch.from_numpy(group[key][volume_slice])

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: "torch.distributed.DeviceMesh"
    ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
        """
        Read this rank's shard of each key from a zarr group.

        Returns the local tensors plus, per key, the (placements,
        chunk_sizes) needed to construct ShardTensors later; the conversion
        itself is deferred until after the CPU -> GPU transfer.
        """
        # Coordinates of this GPU in the domain-parallel mesh:
        this_rank = device_mesh.get_local_rank()
        domain_size = dist.get_world_size(group=device_mesh.get_group())

        group = zarr.open_group(filename, mode="r")

        missing_keys = set(self.keys_to_read) - set(group.keys())
        if missing_keys:
            raise ValueError(f"Keys {missing_keys} not found in file {filename}")

        data = {}
        specs = {}
        for key in self.keys_to_read:
            # Open the array in zarr without reading it and get info:
            zarr_array = group[key]
            array_shape = zarr_array.shape
            if array_shape == ():
                # Scalars: read on every rank and use replicate sharding.
                raw_data = torch.from_numpy(zarr_array[:])
                placement = [Replicate()]
                chunk_sizes = None
            else:
                target_dim = 0
                if array_shape[target_dim] < domain_size:
                    # Too small to split across the ranks: read fully and
                    # replicate.
                    raw_data = torch.from_numpy(zarr_array[:])
                    placement = [Replicate()]
                    chunk_sizes = None
                else:
                    # Read only this rank's rows; shard along dim 0.
                    chunk_start, chunk_stop, chunk_sizes = self._get_slice_boundaries(
                        zarr_array.shape, this_rank, domain_size
                    )
                    raw_data = torch.from_numpy(zarr_array[chunk_start:chunk_stop])
                    placement = [Shard(target_dim)]
                    # Chunk sizes keyed by mesh dim 0:
                    chunk_sizes = {0: chunk_sizes}

            data[key] = raw_data
            specs[key] = (placement, chunk_sizes)

        # Patch in the optional keys; defaults are replicated everywhere.
        data = self.fill_optional_keys(data)
        for key in data.keys():
            if key not in specs:
                specs[key] = ([Replicate()], {})

        return data, specs
407
+
408
+
409
if PV_AVAILABLE:

    class VTKFileReader(BackendReader):
        """
        Reader for VTK-family case directories (.stl / .vtp / .vtu files).

        ``read_file`` receives a directory and opens only the files required
        by the requested keys.
        """

        def __init__(
            self,
            keys_to_read: list[str] | None,
            keys_to_read_if_available: dict[str, torch.Tensor] | None,
        ) -> None:
            super().__init__(keys_to_read, keys_to_read_if_available)

            # Which logical keys come from which file type.
            self.stl_file_keys = [
                "stl_coordinates",
                "stl_centers",
                "stl_faces",
                "stl_areas",
            ]
            self.vtp_file_keys = [
                "surface_mesh_centers",
                "surface_normals",
                "surface_mesh_sizes",
                "CpMeanTrim",
                "pMeanTrim",
                "wallShearStressMeanTrim",
            ]
            self.vtu_file_keys = [
                "volume_mesh_centers",
                "volume_fields",
            ]

            # File-name substrings to skip when locating candidate files.
            self.exclude_patterns = [
                "single_solid",
            ]

        def get_file_name(self, dir_name: pathlib.Path, extension: str) -> pathlib.Path:
            """
            Locate the first file under ``dir_name`` with ``extension`` whose
            name matches no exclude pattern.

            Raises
            ------
            FileNotFoundError
                If no matching file exists.
            """
            matches = [
                p
                for p in dir_name.iterdir()
                if p.suffix == extension
                and not any(pattern in p.name for pattern in self.exclude_patterns)
            ]
            if len(matches) == 0:
                raise FileNotFoundError(f"No {extension} files found in {dir_name}")
            # iterdir() already yields full paths; re-joining the match with
            # dir_name (as before) duplicated the directory component for
            # relative paths.
            return matches[0]

        def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
            """
            Read the requested keys for one case directory.
            """
            # This reader attempts to only read what's necessary, and not
            # more: each file type is opened only if a requested key needs it.
            return_data = {}

            # Note that this reader is, already, running in a background
            # thread. It may or may not help to further thread these calls.
            if any(key in self.stl_file_keys for key in self.keys_to_read):
                stl_path = self.get_file_name(filename, ".stl")
                return_data.update(self.read_data_from_stl(stl_path))
            if any(key in self.vtp_file_keys for key in self.keys_to_read):
                vtp_path = self.get_file_name(filename, ".vtp")
                return_data.update(self.read_data_from_vtp(vtp_path))
            if any(key in self.vtu_file_keys for key in self.keys_to_read):
                raise NotImplementedError("VTU files are not supported yet.")

            return self.fill_optional_keys(return_data)

        def read_file_sharded(
            self, filename: pathlib.Path, device_mesh: "torch.distributed.DeviceMesh"
        ) -> "tuple[dict[str, torch.Tensor], dict[str, ShardTensorSpec]]":
            """
            Not implemented for VTK data.

            Signature aligned with ``BackendReader.read_file_sharded``
            (previously took ``parallel_rank`` / ``parallel_size``,
            inconsistent with the base class contract).
            """
            raise NotImplementedError("Not implemented yet.")

        def read_data_from_stl(
            self,
            stl_path: str,
        ) -> dict:
            """
            Reads surface mesh data from an STL file and prepares a batch
            dictionary for inference.

            Args:
                stl_path (str): Path to the STL file.

            Returns:
                dict: Batch dictionary with mesh faces, coordinates and cell
                normals as torch tensors.
            """
            mesh = pv.read(stl_path)

            batch = {}

            # pyvista stores faces as [n, i0, i1, ...]; the reshape assumes
            # all-triangle cells (n == 3 — TODO confirm inputs), and the
            # leading count column is dropped.
            faces = mesh.faces.reshape(-1, 4)
            faces = faces[:, 1:]

            batch["stl_faces"] = faces.flatten()

            batch["stl_coordinates"] = mesh.points
            batch["surface_normals"] = mesh.cell_normals

            batch = {k: torch.from_numpy(v) for k, v in batch.items()}

            return batch

        def read_data_from_vtp(self, vtp_path: str) -> dict:
            """
            Read vtp file from a file
            """
            raise NotImplementedError("Not implemented yet.")

        def set_volume_sampling_size(self, volume_sampling_size: int):
            """
            This is not supported for vtk files.
            """
            raise NotImplementedError(
                "volume sampling directly from disk is not supported for vtk files."
            )
540
+
541
+
542
if TENSORSTORE_AVAILABLE:

    class TensorStoreZarrReader(BackendReader):
        """
        Reader for zarr stores via TensorStore, using async opens and reads.
        """

        def __init__(
            self,
            keys_to_read: list[str] | None,
            keys_to_read_if_available: dict[str, torch.Tensor] | None,
            cache_bytes_limit: int = 10_000_000,
            data_copy_concurrency: int = 72,
            file_io_concurrency: int = 72,
        ) -> None:
            """
            Parameters
            ----------
            cache_bytes_limit : int
                TensorStore cache pool size in bytes.
            data_copy_concurrency : int
                Concurrency limit for in-memory copies.
            file_io_concurrency : int
                Concurrency limit for file IO.
            """
            super().__init__(keys_to_read, keys_to_read_if_available)

            # Kept for interface compatibility; _make_spec builds the actual
            # per-open specs.
            self.spec_template = {
                "driver": "auto",
                "kvstore": {
                    "driver": "file",
                    "path": None,
                },
            }

            self.context = ts.Context(
                {
                    "cache_pool": {"total_bytes_limit": cache_bytes_limit},
                    "data_copy_concurrency": {"limit": data_copy_concurrency},
                    "file_io_concurrency": {"limit": file_io_concurrency},
                }
            )

        def _make_spec(self, filename: pathlib.Path, key: str) -> dict:
            """
            Build a fresh tensorstore spec for one array.

            The previous ``self.spec_template.copy()`` was a shallow copy, so
            every open shared (and mutated) the nested "kvstore" dict — a
            race when reads run concurrently in the loader's thread pool.
            """
            return {
                "driver": self.spec_template["driver"],
                "kvstore": {
                    "driver": self.spec_template["kvstore"]["driver"],
                    "path": str(filename) + "/" + str(key),
                },
            }

        def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
            """
            Read one zarr group into a dict of float32 torch tensors.
            """
            # Trigger an async open of each data item:
            open_futures = {}
            for key in self.keys_to_read:
                open_futures[key] = ts.open(
                    self._make_spec(filename, key),
                    create=False,
                    open=True,
                    context=self.context,
                )

            # Wait for all the opens to conclude:
            stores = {key: fut.result() for key, fut in open_futures.items()}

            # Select the slice once, outside the loop, so every volumetric
            # key gets the same rows.
            # NOTE(review): assumes "volume_mesh_centers" is among the keys
            # whenever any key contains "volume" — confirm with callers.
            if self.is_volumetric:
                n_volume = stores["volume_mesh_centers"].shape[0]
                if self.volume_sampling_size is not None:
                    volume_slice = self.select_random_sections_from_slice(
                        0, n_volume, self.volume_sampling_size
                    )
                else:
                    volume_slice = slice(0, n_volume)

            # Trigger an async read of each data item (each result will be a
            # numpy ndarray):
            tensor_futures = {}
            for key in self.keys_to_read:
                if "volume" not in key:
                    tensor_futures[key] = stores[key].read()
                # For the volume data, read only the selected slice:
                else:
                    tensor_futures[key] = stores[key][volume_slice].read()

            # Convert to torch tensors, blocking for each result:
            data = {
                key: torch.as_tensor(tensor_futures[key].result(), dtype=torch.float32)
                for key in self.keys_to_read
            }

            return self.fill_optional_keys(data)

        def read_file_sharded(
            self, filename: pathlib.Path, device_mesh: "torch.distributed.DeviceMesh"
        ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
            """
            Read this rank's shard of each key.

            Returns the local tensors plus, per key, the (placements,
            chunk_sizes) needed to construct ShardTensors after the
            CPU -> GPU transfer.
            """
            # Coordinates of this GPU in the domain-parallel mesh:
            this_rank = device_mesh.get_local_rank()
            domain_size = dist.get_world_size(group=device_mesh.get_group())

            # Async-open every array, then wait for the opens:
            open_futures = {}
            for key in self.keys_to_read:
                open_futures[key] = ts.open(
                    self._make_spec(filename, key),
                    create=False,
                    open=True,
                    context=self.context,
                )

            stores = {key: fut.result() for key, fut in open_futures.items()}

            data = {}
            specs = {}
            for key in self.keys_to_read:
                store = stores[key]
                array_shape = store.shape
                if array_shape == ():
                    # Scalars: read on every rank and use replicate sharding.
                    _slice = np.s_[:]
                    placement = [Replicate()]
                    chunk_sizes = None
                else:
                    target_dim = 0
                    if array_shape[target_dim] < domain_size:
                        # Too small to split across the ranks: read fully and
                        # replicate.
                        _slice = np.s_[:]
                        placement = [Replicate()]
                        chunk_sizes = None
                    else:
                        # Read only this rank's rows; shard along dim 0.
                        chunk_start, chunk_stop, chunk_sizes = (
                            self._get_slice_boundaries(
                                store.shape, this_rank, domain_size
                            )
                        )
                        _slice = np.s_[chunk_start:chunk_stop]
                        placement = [Shard(target_dim)]
                        # Chunk sizes keyed by mesh dim 0:
                        chunk_sizes = {0: chunk_sizes}

                # Trigger the reads as async:
                data[key] = store[_slice].read()
                specs[key] = (placement, chunk_sizes)

            # Finally, await the full data read:
            for key in self.keys_to_read:
                data[key] = torch.as_tensor(data[key].result())

            # Patch in the optional keys; defaults are replicated everywhere.
            data = self.fill_optional_keys(data)
            for key in data.keys():
                if key not in specs:
                    specs[key] = ([Replicate()], {})

            return data, specs
712
+
713
+ else:
714
+
715
+ class TensorStoreZarrReader(BackendReader):
716
+ """
717
+ Null reader for tensorstore zarr files.
718
+ """
719
+
720
+ def __init__(
721
+ self,
722
+ keys_to_read: list[str] | None,
723
+ keys_to_read_if_available: dict[str, torch.Tensor] | None,
724
+ ) -> None:
725
+ # Raise an exception on construction if we get here:
726
+ raise NotImplementedError(
727
+ "TensorStoreZarrReader is not available without tensorstore. `pip install tensorstore`."
728
+ )
729
+
730
+
731
def is_vtk_directory(file: pathlib.Path) -> bool:
    """
    Check whether ``file`` is a directory containing only VTK-style files.

    A "VTK directory" is a directory whose every entry has one of the
    suffixes .vtp, .stl, .vtu, .vtk, or .csv.

    Note: an empty directory vacuously satisfies the check (``all`` of an
    empty iterable is True), matching the previous behavior.

    Args:
        file: Path to inspect.

    Returns:
        True if ``file`` is a directory of VTK-style files, False otherwise.
    """
    if not file.is_dir():
        return False
    # Generator + set membership: stops at the first non-VTK entry instead
    # of materializing the whole directory listing in a list.
    return all(
        f.suffix in {".vtp", ".stl", ".vtu", ".vtk", ".csv"} for f in file.iterdir()
    )
738
+
739
+
740
+ class CAEDataset:
741
+ """
742
+ Dataset reader for DrivaerML and similar datasets. In general, this
743
+ dataset supports reading dictionary-like data, and returning a
744
+ dictionary of torch.Tensor objects.
745
+
746
+ When constructed, the user must pass a directory of data examples.
747
+ The dataset will inspect the folder, identify all children, and decide:
748
+ - If every file is a directory ending in .zarr, the zarr reader is used.
749
+ - If every file is .npy, the .npy reader is used.
750
+ - If every file is .npz, the .npz reader is used.
751
+ - If every file is a directory without an extension, it's assumed to be .stl/.vtp/.vtu
752
+
753
+ The user can optionally force one path with a parameter.
754
+
755
+ The flow of this dataset is:
756
+ - Load data from file, using a thread.
757
+ - Each individual file reading tool may or may not have it's own threading
758
+ or multi processing enabled. That's up to it. This just does async
759
+ loading.
760
+ - Data should come out of the readers in dict{str : torch.Tensor} format
761
+ - The data is transferred from CPU to GPU in a separate stream.
762
+
763
+ Users can call __getitem__(i), which will trigger the pipeline,
764
+ or they can call `preload(i)`, which will start the pipeline for index `i`.
765
+ Subsequent calls to `__getitem__(i)` should be faster since the IO is in
766
+ progress or complete.
767
+
768
+ Using the `__iter__` functionality will automatically enable preloading.
769
+
770
+ """
771
+
772
+ def __init__(
773
+ self,
774
+ data_dir: str | pathlib.Path,
775
+ keys_to_read: list[str] | None,
776
+ keys_to_read_if_available: dict[str, torch.Tensor] | None,
777
+ output_device: torch.device,
778
+ preload_depth: int = 2,
779
+ pin_memory: bool = False,
780
+ device_mesh: torch.distributed.DeviceMesh | None = None,
781
+ placements: dict[str, torch.distributed.tensor.Placement] | None = None,
782
+ consumer_stream: torch.cuda.Stream | None = None,
783
+ ) -> None:
784
+ if isinstance(data_dir, str):
785
+ data_dir = pathlib.Path(data_dir)
786
+
787
+ # Verify the data directory exists:
788
+ if not data_dir.exists():
789
+ raise FileNotFoundError(f"Data directory {data_dir} does not exist")
790
+
791
+ # Verify the data directory is a directory:
792
+ if not data_dir.is_dir():
793
+ raise NotADirectoryError(f"Data directory {data_dir} is not a directory")
794
+
795
+ self._keys_to_read = keys_to_read
796
+
797
+ # Make sure the optional keys are on the right device:
798
+ self._keys_to_read_if_available = {
799
+ k: v.to(output_device) for k, v in keys_to_read_if_available.items()
800
+ }
801
+
802
+ self.file_reader, self._filenames = self._infer_file_type_and_filenames(
803
+ data_dir
804
+ )
805
+
806
+ self.pin_memory = pin_memory
807
+
808
+ # Check the file names; some can be read well in parallel, while others
809
+ # are not parallelizable.
810
+
811
+ self._length = len(self._filenames)
812
+
813
+ self.output_device = output_device
814
+ if output_device.type == "cuda":
815
+ self._data_loader_stream = torch.cuda.Stream()
816
+ else:
817
+ self._data_loader_stream = None
818
+
819
+ self.device_mesh = device_mesh
820
+ self.placements = placements
821
+ # This tracks global tensor info
822
+ # so we can convert to ShardTensor at the right time.
823
+ self.shard_spec = {}
824
+
825
+ if self.device_mesh is not None:
826
+ if self.device_mesh.ndim != 1:
827
+ raise ValueError("Device mesh must be one dimensional")
828
+
829
+ # This is thread storage for data preloading:
830
+ self._preload_queue = {}
831
+ self._transfer_events = {}
832
+ self.preload_depth = preload_depth
833
+ self.preload_executor = ThreadPoolExecutor(max_workers=max(1, preload_depth))
834
+
835
+ if consumer_stream is None and self.output_device.type == "cuda":
836
+ consumer_stream = torch.cuda.current_stream()
837
+
838
+ self.consumer_stream = consumer_stream
839
+
840
+ def set_indices(self, indices: list[int]):
841
+ """
842
+ Set the indices for the dataset for this epoch.
843
+ """
844
+
845
+ # TODO - this needs to block while anything is in the preprocess queue.
846
+
847
+ self.indices = indices
848
+
849
+ def idx_to_index(self, idx):
850
+ if hasattr(self, "indices"):
851
+ return self.indices[idx]
852
+
853
+ return idx
854
+
855
    def _infer_file_type_and_filenames(
        self, data_dir: pathlib.Path
    ) -> tuple["BackendReader", list[pathlib.Path]]:
        """
        Inspect the data directory and pick the matching backend reader.

        Every child of ``data_dir`` must share one format:
        - all .npy files         -> NpyFileReader
        - all .npz files         -> NpzFileReader
        - all .zarr directories  -> TensorStoreZarrReader when tensorstore
                                    is installed, else ZarrFileReader
        - all VTK-style subdirs  -> VTKFileReader

        Returns:
            Tuple of (reader instance, list of per-example paths).

        Raises:
            ValueError: If the directory contents match none of the formats.
        """

        # We validated the directory exists and is a directory already.

        # List the files:
        files = list(data_dir.iterdir())

        # Initialize the file reader object
        # Note that for some of these, they could be functions
        # But others benefit from having a state, so we use classes:

        # NOTE(review): an empty directory vacuously matches the first
        # branch (all() of an empty iterable is True) and yields an empty
        # dataset — confirm that is intended.
        if all(file.suffix == ".npy" for file in files):
            file_reader = NpyFileReader(
                self._keys_to_read, self._keys_to_read_if_available
            )
            return file_reader, files
        elif all(file.suffix == ".npz" for file in files):
            file_reader = NpzFileReader(
                self._keys_to_read, self._keys_to_read_if_available
            )
            return file_reader, files
        elif all(file.suffix == ".zarr" and file.is_dir() for file in files):
            # Prefer the async tensorstore backend when available:
            if TENSORSTORE_AVAILABLE:
                file_reader = TensorStoreZarrReader(
                    self._keys_to_read, self._keys_to_read_if_available
                )
            else:
                file_reader = ZarrFileReader(
                    self._keys_to_read, self._keys_to_read_if_available
                )
            return file_reader, files
        elif all(is_vtk_directory(file) for file in files):
            # Each "file" here is a directory of .vtp, stl, etc.
            file_reader = VTKFileReader(
                self._keys_to_read, self._keys_to_read_if_available
            )
            return file_reader, files
        else:
            # TODO - support folders of stl, vtp, vtu.
            raise ValueError(f"Unsupported file type: {files[0]}")
900
+
901
    def _move_to_gpu(
        self, data: dict[str, torch.Tensor], idx: int
    ) -> dict[str, torch.Tensor]:
        """Asynchronously copy a sample's tensors to the output device.

        Copies are enqueued on the side stream ``self._data_loader_stream``
        so they overlap with compute on the consumer stream. A CUDA event is
        recorded per sample into ``self._transfer_events[idx]`` so that
        ``__getitem__`` can synchronize the consumer stream before handing
        the data out.

        Args:
            data: Dictionary of key to torch tensor (host or device).
            idx: Sample index, used as the key for the transfer event.

        Returns:
            Dictionary of key to tensor on the output device. Returned
            unchanged — and with no event recorded — when the output device
            is not CUDA.
        """

        # CPU output: nothing to transfer, no stream/event bookkeeping.
        if self.output_device.type != "cuda":
            return data

        result = {}

        with torch.cuda.stream(self._data_loader_stream):
            for key in data.keys():
                # Already resident on the target device: pass through
                # (record_stream is skipped for these).
                if data[key].device == self.output_device:
                    result[key] = data[key]
                    continue
                if self.pin_memory:
                    # Staging through pinned memory enables a truly
                    # asynchronous host-to-device copy.
                    result[key] = (
                        data[key].pin_memory().to(self.output_device, non_blocking=True)
                    )
                else:
                    result[key] = data[key].to(self.output_device, non_blocking=True)
                # Tell the caching allocator the consumer stream will use
                # this memory, so it is not reused while still in flight.
                result[key].record_stream(self.consumer_stream)

        # Record an event marking the end of this sample's transfers on the
        # loader stream; __getitem__ waits on it from the consumer stream.
        transfer_event = torch.cuda.Event()
        transfer_event.record(self._data_loader_stream)
        self._transfer_events[idx] = transfer_event

        return result
939
+
940
    def _convert_to_shard_tensors(
        self,
        tensors: dict[str, torch.Tensor],
        filename: str,
    ) -> dict[str, "ShardTensor"]:
        """Wrap locally-read tensors into ShardTensors for domain parallelism.

        Uses the sharding spec (placements + per-rank chunk sizes) that
        ``_read_file`` recorded for this ``filename``. The spec entry is
        consumed (popped), so this must be called at most once per read.

        Args:
            tensors: Dictionary of key to this rank's local tensor.
            filename: Key under which ``_read_file`` stored the shard spec.

        Returns:
            The input dict unchanged when no device mesh is configured,
            otherwise a dict of key to ShardTensor.
        """

        # No domain parallelism configured: plain tensors pass through.
        if self.device_mesh is None:
            return tensors

        # The spec is single-use; it was stashed by _read_file for this file.
        spec_dict = self.shard_spec.pop(filename)
        result = {}
        for key in tensors.keys():
            placement, chunk_sizes = spec_dict[key]

            result[key] = ShardTensor.from_local(
                local_tensor=tensors[key],
                device_mesh=self.device_mesh,
                placements=placement,
                sharding_shapes=chunk_sizes,
            )

        return result
970
+
971
+ def preload(self, idx: int) -> None:
972
+ """
973
+ Asynchronously preload the data for the given index (up to CPU, not GPU).
974
+ Only one preload operation is supported at a time.
975
+
976
+ Args:
977
+ idx: Index of the sample to preload.
978
+ """
979
+ if idx in self._preload_queue:
980
+ # Skip items that are already in the queue
981
+ return
982
+
983
+ def _preload_worker():
984
+ data = self._read_file(self._filenames[idx])
985
+ if "stl_faces" in data:
986
+ data["stl_faces"] = data["stl_faces"].to(torch.int32)
987
+ # Convert to torch tensors
988
+ return self._move_to_gpu(data, idx)
989
+
990
+ self._preload_queue[idx] = self.preload_executor.submit(_preload_worker)
991
+
992
+ def get_preloaded(self, idx: int) -> dict[str, torch.Tensor] | None:
993
+ """
994
+ Retrieve the preloaded data (blocking if not ready).
995
+
996
+ Returns:
997
+ (idx, data) tuple where data is a dictionary of key to numpy array or torch tensor.
998
+
999
+ Raises:
1000
+ RuntimeError: If no preload is in progress.
1001
+ Exception: If preload failed.
1002
+ """
1003
+
1004
+ if idx not in self._preload_queue:
1005
+ return None
1006
+
1007
+ result = self._preload_queue[
1008
+ idx
1009
+ ].result() # This will block until the result is ready
1010
+ self._preload_queue.pop(idx) # Clear the future after getting the result
1011
+
1012
+ return result
1013
+
1014
+ def __iter__(self):
1015
+ # When starting the iterator method, start loading the data
1016
+ # at idx = 0, idx = 1
1017
+ # Start preprocessing at idx = 0, when the load completes
1018
+
1019
+ self.i = 0
1020
+
1021
+ N = len(self.indices) if hasattr(self, "indices") else len(self)
1022
+ for i in range(self.preload_depth):
1023
+ # Trigger the dataset to start loading index 0:
1024
+ if N > i + 1:
1025
+ self.preload(self.idx_to_index(self.i + i))
1026
+
1027
+ return self
1028
+
1029
    def __next__(self):
        """Return the current sample and queue upcoming preloads.

        ``self.i`` is the iteration counter; ``idx_to_index`` maps it
        through the optional epoch permutation to a file index.
        """
        # Epoch length: the permutation length if set, else the file count.
        N = len(self.indices) if hasattr(self, "indices") else len(self._filenames)

        # Iteration bounds are based on the counter, not the random-access index
        if self.i >= N:
            # Reset the counter so the dataset can be iterated again.
            self.i = 0
            raise StopIteration

        # This is the file random access index
        target_index = self.idx_to_index(self.i)

        # Before returning, queue the next `preload_depth` samples so their
        # IO overlaps with downstream compute:
        for preload_i in range(self.preload_depth):
            next_iteration_index = self.i + preload_i + 1
            if N > next_iteration_index:
                preload_idx = self.idx_to_index(next_iteration_index)
                self.preload(preload_idx)

        # Send up the random-access data:
        data = self.__getitem__(target_index)

        self.i += 1

        return data
1053
+
1054
+ def __len__(self):
1055
+ return len(self._filenames)
1056
+
1057
+ def _read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
1058
+ """
1059
+ Read a file and return a dictionary of tensors.
1060
+ """
1061
+ if self.device_mesh is not None:
1062
+ tensor_dict, spec_dict = self.file_reader.read_file_sharded(
1063
+ filename, self.device_mesh
1064
+ )
1065
+ self.shard_spec[filename] = spec_dict
1066
+ return tensor_dict
1067
+ else:
1068
+ return self.file_reader.read_file(filename)
1069
+
1070
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor | ShardTensor]:
        """
        Get a data sample.

        Flow is:
        - Read data, or claim the preloaded future if this idx is preloaded.
        - Move data to GPU, if needed (a preload has already done this).
        - Wait on the sample's transfer event so the consumer stream sees
          completed copies.
        - If domain parallelism is enabled, convert to ShardTensors.
        - Return

        Args:
            idx: Index of the sample to retrieve

        Returns:
            Dictionary containing tensors/ShardTensors for the requested data

        Raises:
            IndexError: If ``idx`` is outside the file list.
        """

        if idx >= len(self._filenames):
            raise IndexError(
                f"Index {idx} out of range for dataset of size {len(self._filenames)}"
            )

        # Attempt to get preloaded data (None when no preload was issued):
        data = self.get_preloaded(idx)
        if data is None:
            # Fall back to a blocking read + transfer.
            data = self._read_file(self._filenames[idx])
            data = self._move_to_gpu(data, idx)

        # Make the consumer stream wait until this sample's host-to-device
        # transfers (recorded in _move_to_gpu) have completed:
        if idx in self._transfer_events:
            self.consumer_stream.wait_event(self._transfer_events[idx])
            self._transfer_events.pop(idx)

        # Convert to ShardTensors if using domain parallelism
        if self.device_mesh is not None:
            data = self._convert_to_shard_tensors(data, self._filenames[idx])

        return data
1110
+
1111
+ def set_volume_sampling_size(self, volume_sampling_size: int):
1112
+ """
1113
+ Set the volume sampling size. When set, the readers will
1114
+ assume the volumetric data is shuffled on disk and read only
1115
+ contiguous chunks of the data up to the sampling size.
1116
+
1117
+ Args:
1118
+ volume_sampling_size: The total size of the volume sampling.
1119
+ """
1120
+ self.file_reader.set_volume_sampling_size(volume_sampling_size)
1121
+
1122
+ def close(self):
1123
+ """
1124
+ Explicitly close the dataset and cleanup resources, including the ThreadPoolExecutor.
1125
+ """
1126
+ if hasattr(self, "preload_executor") and self.preload_executor is not None:
1127
+ self.preload_executor.shutdown(wait=True)
1128
+ self.preload_executor = None
1129
+
1130
+ def __del__(self):
1131
+ """
1132
+ Cleanup resources when the dataset is destroyed.
1133
+ """
1134
+ self.close()
1135
+
1136
+
1137
def compute_mean_std_min_max(
    dataset: "CAEDataset", field_keys: list[str], max_samples: int = 20
) -> tuple[dict, dict, dict, dict]:
    """
    Compute per-channel mean, std, min, and max for the given fields over a
    random subset of the dataset.

    Pass 1 merges per-sample statistics with the parallel (Chan et al.)
    form of Welford's algorithm, which is numerically stable. Pass 2
    computes per-channel min/max while ignoring values more than 9 standard
    deviations from the mean (outlier rejection).

    Args:
        dataset: Indexable dataset yielding dict[str, torch.Tensor] samples;
            each field tensor is shaped (num_points, num_channels).
        field_keys: Keys of the fields to accumulate statistics for.
        max_samples: Maximum number of randomly chosen samples to visit.

    Returns:
        Four dicts keyed by field name — mean, std, min, max — each value a
        tensor with one entry per channel. (std uses Bessel's correction.)
    """
    N = {}
    mean = {}
    M2 = {}  # Sum of squares of differences from the current mean
    min_val = {}
    max_val = {}

    # Read the first data item only to size the per-channel accumulators:
    example_data = dataset[0]
    for key in field_keys:
        device = example_data[key].device
        channels = example_data[key].shape[-1]
        N[key] = torch.zeros(1, dtype=torch.int64, device=device)
        mean[key] = torch.zeros(channels, device=device, dtype=torch.float64)
        M2[key] = torch.zeros(channels, device=device, dtype=torch.float64)
        min_val[key] = torch.full((channels,), float("inf"), device=device)
        max_val[key] = torch.full((channels,), float("-inf"), device=device)

    global_start = time.perf_counter()
    start = time.perf_counter()
    data_list = np.arange(len(dataset))
    np.random.shuffle(data_list)

    # Pass 1: accumulate mean and M2.
    for i, j in enumerate(data_list):
        # Check the budget BEFORE reading: the old code read one extra
        # (discarded) sample past max_samples.
        if i >= max_samples:
            break
        data = dataset[j]

        for field_key in field_keys:
            field_data = data[field_key]

            # Per-sample statistics:
            batch_mean = field_data.mean(axis=(0))
            batch_M2 = ((field_data - batch_mean) ** 2).sum(axis=(0))
            batch_n = field_data.shape[0]

            # Parallel Welford merge (Chan et al.). BUG FIX: the cross term
            # must be delta^2 * n_prev * batch_n / n_new; the old code used
            # the *updated* count in the numerator, which simplifies to
            # delta^2 * batch_n and inflates the variance.
            n_prev = N[field_key].clone()
            delta = batch_mean - mean[field_key]
            N[field_key] += batch_n
            mean[field_key] = mean[field_key] + delta * (batch_n / N[field_key])
            M2[field_key] = (
                M2[field_key]
                + batch_M2
                + delta**2 * (n_prev * batch_n) / N[field_key]
            )

        end = time.perf_counter()
        iteration_time = end - start
        print(
            f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}"
        )
        start = time.perf_counter()

    var = {}
    std = {}
    for field_key in field_keys:
        # Sample variance (Bessel's correction); .item() converts the
        # running count to a Python int for the division.
        var[field_key] = M2[field_key] / (N[field_key].item() - 1)
        std[field_key] = torch.sqrt(var[field_key])

    # Pass 2: per-channel min/max with 9-sigma outlier rejection.
    start = time.perf_counter()
    for i, j in enumerate(data_list):
        if i >= max_samples:
            break
        data = dataset[j]

        for field_key in field_keys:
            field_data = data[field_key]
            mean_sample = mean[field_key]
            std_sample = std[field_key]

            batch_min = []
            batch_max = []
            for v in range(field_data.shape[-1]):
                # Keep values within +/- 9 sigma of the running mean:
                inliers = (
                    field_data[:, v] >= mean_sample[v] - 9.0 * std_sample[v]
                ) & (field_data[:, v] <= mean_sample[v] + 9.0 * std_sample[v])
                batch_min.append(field_data[inliers, v].min())
                batch_max.append(field_data[inliers, v].max())

            batch_min = torch.stack(batch_min)
            batch_max = torch.stack(batch_max)

            min_val[field_key] = torch.minimum(min_val[field_key], batch_min)
            max_val[field_key] = torch.maximum(max_val[field_key], batch_max)

        end = time.perf_counter()
        iteration_time = end - start
        print(
            f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}"
        )
        start = time.perf_counter()

    global_end = time.perf_counter()
    global_time = global_end - global_start

    print(f"Total time: {global_time:.2f} seconds for {max_samples} samples")

    return mean, std, min_val, max_val
physics_mcp/source/physicsnemo/datapipes/cae/domino_datapipe.py ADDED
@@ -0,0 +1,1334 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ """
18
+ This code provides the datapipe for reading the processed npy files,
19
+ generating multi-res grids, calculating signed distance fields,
20
+ sampling random points in the volume and on surface,
21
+ normalizing fields and returning the output tensors as a dictionary.
22
+
23
+ This datapipe also non-dimensionalizes the fields, so the order in which the variables should
24
+ be fixed: velocity, pressure, turbulent viscosity for volume variables and
25
+ pressure, wall-shear-stress for surface variables. The different parameters such as
26
+ variable names, domain resolution, sampling size etc. are configurable in config.yaml.
27
+ """
28
+
29
+ from dataclasses import dataclass
30
+ from pathlib import Path
31
+ from typing import Iterable, Literal, Optional, Protocol, Sequence, Union
32
+
33
+ import numpy as np
34
+ import torch
35
+ import torch.cuda.nvtx as nvtx
36
+ import torch.distributed as dist
37
+ from omegaconf import DictConfig
38
+ from torch.distributed.tensor.placement_types import Replicate, Shard
39
+ from torch.utils.data import Dataset
40
+
41
+ from physicsnemo.datapipes.cae.cae_dataset import (
42
+ CAEDataset,
43
+ compute_mean_std_min_max,
44
+ )
45
+ from physicsnemo.distributed import DistributedManager
46
+ from physicsnemo.distributed.shard_tensor import ShardTensor, scatter_tensor
47
+ from physicsnemo.utils.domino.utils import (
48
+ calculate_center_of_mass,
49
+ create_grid,
50
+ get_filenames,
51
+ normalize,
52
+ pad,
53
+ shuffle_array,
54
+ standardize,
55
+ unnormalize,
56
+ unstandardize,
57
+ )
58
+ from physicsnemo.utils.neighbors import knn
59
+ from physicsnemo.utils.profiling import profile
60
+ from physicsnemo.utils.sdf import signed_distance_field
61
+
62
+
63
class BoundingBox(Protocol):
    """
    Structural type for axis-aligned bounding-box limits.

    Any object exposing arraylike ``min`` and ``max`` attributes (one
    coordinate per spatial axis) satisfies this protocol.
    """

    # Lower corner of the box, one coordinate per axis.
    min: Sequence
    # Upper corner of the box, one coordinate per axis.
    max: Sequence
70
+
71
+
72
@dataclass
class DoMINODataConfig:
    """Configuration for DoMINO dataset processing pipeline.

    Attributes:
        data_path: Path to the dataset to load.
        phase: Which phase of data to load ("train", "val", or "test").
        surface_variables: (Surface specific) Names of surface variables.
        surface_points_sample: (Surface specific) Number of surface points to sample per batch.
        num_surface_neighbors: (Surface specific) Number of surface neighbors to consider for nearest neighbors approach.
        surface_sampling_algorithm: (Surface specific) Algorithm to use for surface sampling ("area_weighted" or "random").
        surface_factors: (Surface specific) Non-dimensionalization factors for surface variables.
            If set, and scaling_type is:
            - min_max_scaling -> rescale surface_fields to the min/max set here
            - mean_std_scaling -> rescale surface_fields to the mean and std set here.
        bounding_box_dims_surf: (Surface specific) Dimensions of bounding box. Must be an object with min/max
            attributes that are arraylike.
        volume_variables: (Volume specific) Names of volume variables.
        volume_points_sample: (Volume specific) Number of volume points to sample per batch.
        volume_sample_from_disk: (Volume specific) If the volume data is in a shuffled state on disk,
            read contiguous chunks of the data rather than the entire volume data. This greatly
            accelerates IO in bandwidth limited systems or when the volumetric data is very large.
        volume_factors: (Volume specific) Non-dimensionalization factors for volume variables scaling.
            If set, and scaling_type is:
            - min_max_scaling -> rescale volume_fields to the min/max set here
            - mean_std_scaling -> rescale volume_fields to the mean and std set here.
        bounding_box_dims: (Volume specific) Dimensions of bounding box. Must be an object with min/max
            attributes that are arraylike.
        grid_resolution: Resolution of the latent grid.
        normalize_coordinates: Whether to normalize coordinates based on min/max values.
            For surfaces: uses s_min/s_max, defined from:
            - Surface bounding box, if defined.
            - Min/max of the stl_vertices
            For volumes: uses c_min/c_max, defined from:
            - Volume bounding_box if defined,
            - 1.5x s_min/max otherwise, except c_min[2] = s_min[2] in this case
        sample_in_bbox: Whether to sample points in a specified bounding box.
            Uses the same min/max points as coordinate normalization.
            Only performed if compute_scaling_factors is false.
        sampling: Whether to downsample the full resolution mesh to fit in GPU memory.
            Surface and volume sampling points are configured separately as:
            - surface.points_sample
            - volume.points_sample
        geom_points_sample: Number of STL points sampled per batch.
            Independent of volume.points_sample and surface.points_sample.
        scaling_type: Scaling type for volume variables.
            If used, will rescale the volume_fields and surface fields outputs.
            Requires volume.factor and surface.factor to be set.
        compute_scaling_factors: Whether to compute scaling factors.
            Not available if caching.
            Many preprocessing pieces are disabled if computing scaling factors.
        caching: Whether this is for caching or serving.
        deterministic: Whether to use a deterministic seed for sampling and random numbers.
        gpu_preprocessing: Whether to do preprocessing on the GPU (False for CPU).
        gpu_output: Whether to return output on the GPU as cupy arrays.
            If False, returns numpy arrays.
            You might choose gpu_preprocessing=True and gpu_output=False if caching.
        shard_grid: Whether to shard the grid across GPUs for domain parallelism.
            Applies to the surf_grid and similar tensors.
        shard_points: Whether to shard the points across GPUs for domain parallelism.
            Applies to the volume_fields/surface_fields and similar tensors.
    """

    data_path: Path | None
    phase: Literal["train", "val", "test"]

    # Surface-specific variables:
    surface_variables: Optional[Sequence] = ("pMean", "wallShearStress")
    surface_points_sample: int = 1024
    num_surface_neighbors: int = 11
    # BUG FIX: previously this was annotated `str` and *defaulted to the
    # Literal type object itself* (`= Literal[...]`), so the field held a
    # typing construct rather than a usable algorithm name.
    surface_sampling_algorithm: Literal["area_weighted", "random"] = "area_weighted"
    surface_factors: Optional[Sequence] = None
    bounding_box_dims_surf: Optional[Union[BoundingBox, Sequence]] = None

    # Volume specific variables:
    volume_variables: Optional[Sequence] = ("UMean", "pMean")
    volume_points_sample: int = 1024
    volume_sample_from_disk: bool = False
    volume_factors: Optional[Sequence] = None
    bounding_box_dims: Optional[Union[BoundingBox, Sequence]] = None

    grid_resolution: Sequence = (256, 96, 64)
    normalize_coordinates: bool = False
    sample_in_bbox: bool = False
    sampling: bool = False
    geom_points_sample: int = 300000
    scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None
    compute_scaling_factors: bool = False
    caching: bool = False
    deterministic: bool = False
    gpu_preprocessing: bool = True
    gpu_output: bool = True

    shard_grid: bool = False
    shard_points: bool = False

    def __post_init__(self):
        """Normalize the data path and validate the configuration."""
        if self.data_path is not None:
            # Ensure data_path is a Path object:
            if isinstance(self.data_path, str):
                self.data_path = Path(self.data_path)
            self.data_path = self.data_path.expanduser()

            if not self.data_path.exists():
                raise ValueError(f"Path {self.data_path} does not exist")

            if not self.data_path.is_dir():
                raise ValueError(f"Path {self.data_path} is not a directory")

        # Reject caching configurations that cannot work:
        if self.caching:
            if self.sampling:
                raise ValueError("Sampling should be False for caching")
            if self.compute_scaling_factors:
                raise ValueError("Compute scaling factors should be False for caching")

        if self.phase not in [
            "train",
            "val",
            "test",
        ]:
            raise ValueError(
                f"phase should be one of ['train', 'val', 'test'], got {self.phase}"
            )

        # New validation, enabled by the fixed default above:
        if self.surface_sampling_algorithm not in [
            "area_weighted",
            "random",
        ]:
            raise ValueError(
                f"surface_sampling_algorithm should be one of ['area_weighted', 'random'], "
                f"got {self.surface_sampling_algorithm}"
            )

        if self.scaling_type is not None:
            if self.scaling_type not in [
                "min_max_scaling",
                "mean_std_scaling",
            ]:
                raise ValueError(
                    f"scaling_type should be one of ['min_max_scaling', 'mean_std_scaling'], got {self.scaling_type}"
                )
204
+
205
+
206
+ ##### TODO
207
+ # - The SDF normalization here is based on using a normalized mesh and
208
+ # a normalized coordinate. The alternate method is to normalize to the min/max of the grid.
209
+
210
+
211
+ class DoMINODataPipe(Dataset):
212
+ """
213
+ Datapipe for DoMINO
214
+
215
+ Leverages a dataset for the actual reading of the data, and this
216
+ object is responsible for preprocessing the data.
217
+
218
+ """
219
+
220
+ def __init__(
221
+ self,
222
+ input_path,
223
+ model_type: Literal["surface", "volume", "combined"],
224
+ pin_memory: bool = False,
225
+ **data_config_overrides,
226
+ ):
227
+ # Perform config packaging and validation
228
+ self.config = DoMINODataConfig(data_path=input_path, **data_config_overrides)
229
+
230
+ # Set up the distributed manager:
231
+ if not DistributedManager.is_initialized():
232
+ DistributedManager.initialize()
233
+
234
+ dist = DistributedManager()
235
+
236
+ # Set devices for the preprocessing and IO target
237
+ self.preproc_device = (
238
+ dist.device if self.config.gpu_preprocessing else torch.device("cpu")
239
+ )
240
+ # The cae_dataset will automatically target this device
241
+ # In an async transfer.
242
+ self.output_device = (
243
+ dist.device if self.config.gpu_output else torch.device("cpu")
244
+ )
245
+
246
+ # Model type determines whether we process surface, volume, or both.
247
+ self.model_type = model_type
248
+
249
+ # Update the arrays for bounding boxes:
250
+ if hasattr(self.config.bounding_box_dims, "max") and hasattr(
251
+ self.config.bounding_box_dims, "min"
252
+ ):
253
+ self.config.bounding_box_dims = [
254
+ torch.tensor(
255
+ self.config.bounding_box_dims.max,
256
+ device=self.preproc_device,
257
+ dtype=torch.float32,
258
+ ),
259
+ torch.tensor(
260
+ self.config.bounding_box_dims.min,
261
+ device=self.preproc_device,
262
+ dtype=torch.float32,
263
+ ),
264
+ ]
265
+ self.default_volume_grid = create_grid(
266
+ self.config.bounding_box_dims[0],
267
+ self.config.bounding_box_dims[1],
268
+ self.config.grid_resolution,
269
+ )
270
+
271
+ # And, do the surface bounding box if supplied:
272
+ if hasattr(self.config.bounding_box_dims_surf, "max") and hasattr(
273
+ self.config.bounding_box_dims_surf, "min"
274
+ ):
275
+ self.config.bounding_box_dims_surf = [
276
+ torch.tensor(
277
+ self.config.bounding_box_dims_surf.max,
278
+ device=self.preproc_device,
279
+ dtype=torch.float32,
280
+ ),
281
+ torch.tensor(
282
+ self.config.bounding_box_dims_surf.min,
283
+ device=self.preproc_device,
284
+ dtype=torch.float32,
285
+ ),
286
+ ]
287
+
288
+ self.default_surface_grid = create_grid(
289
+ self.config.bounding_box_dims_surf[0],
290
+ self.config.bounding_box_dims_surf[1],
291
+ self.config.grid_resolution,
292
+ )
293
+
294
+ # Ensure the volume and surface scaling factors are torch tensors
295
+ # and on the right device:
296
+ if self.config.volume_factors is not None:
297
+ if not isinstance(self.config.volume_factors, torch.Tensor):
298
+ self.config.volume_factors = torch.from_numpy(
299
+ self.config.volume_factors
300
+ )
301
+ self.config.volume_factors = self.config.volume_factors.to(
302
+ self.preproc_device, dtype=torch.float32
303
+ )
304
+ if self.config.surface_factors is not None:
305
+ if not isinstance(self.config.surface_factors, torch.Tensor):
306
+ self.config.surface_factors = torch.from_numpy(
307
+ self.config.surface_factors
308
+ )
309
+ self.config.surface_factors = self.config.surface_factors.to(
310
+ self.preproc_device, dtype=torch.float32
311
+ )
312
+
313
+ self.dataset = None
314
+
315
+ def compute_stl_scaling_and_surface_grids(
316
+ self,
317
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
318
+ """
319
+ Compute the min and max for the defining mesh.
320
+
321
+ If the user supplies a bounding box, we use that. Otherwise,
322
+ it raises an error.
323
+
324
+ The returned min/max and grid are used for surface data.
325
+ """
326
+
327
+ # Check the bounding box is not unit length
328
+
329
+ if self.config.bounding_box_dims_surf is not None:
330
+ s_max = self.config.bounding_box_dims_surf[0]
331
+ s_min = self.config.bounding_box_dims_surf[1]
332
+ surf_grid = self.default_surface_grid
333
+ else:
334
+ raise ValueError("Bounding box dimensions are not set in config")
335
+
336
+ return s_min, s_max, surf_grid
337
+
338
+ def compute_volume_scaling_and_grids(
339
+ self,
340
+ ) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
341
+ """
342
+ Compute the min and max and grid for volume data.
343
+
344
+ If the user supplies a bounding box, we use that. Otherwise,
345
+ it raises an error.
346
+
347
+ """
348
+
349
+ # Determine the volume min / max locations
350
+ if self.config.bounding_box_dims is not None:
351
+ c_max = self.config.bounding_box_dims[0]
352
+ c_min = self.config.bounding_box_dims[1]
353
+ volume_grid = self.default_volume_grid
354
+ else:
355
+ raise ValueError("Bounding box dimensions are not set in config")
356
+
357
+ return c_min, c_max, volume_grid
358
+
359
+ @profile
360
+ def downsample_geometry(
361
+ self,
362
+ stl_vertices,
363
+ ) -> torch.Tensor:
364
+ """
365
+ Downsample the geometry to the desired number of points.
366
+
367
+ Args:
368
+ stl_vertices: The vertices of the surface.
369
+ """
370
+
371
+ if self.config.sampling:
372
+ geometry_points = self.config.geom_points_sample
373
+
374
+ geometry_coordinates_sampled, idx_geometry = shuffle_array(
375
+ stl_vertices, geometry_points
376
+ )
377
+ if geometry_coordinates_sampled.shape[0] < geometry_points:
378
+ raise ValueError(
379
+ "Surface mesh has fewer points than requested sample size"
380
+ )
381
+ geom_centers = geometry_coordinates_sampled
382
+ else:
383
+ geom_centers = stl_vertices
384
+
385
+ return geom_centers
386
+
387
    def process_surface(
        self,
        s_min: torch.Tensor,
        s_max: torch.Tensor,
        c_min: torch.Tensor,
        c_max: torch.Tensor,
        *,  # Forcing the rest by keyword only since it's a long list ...
        center_of_mass: torch.Tensor,
        surf_grid: torch.Tensor,
        surface_coordinates: torch.Tensor,
        surface_normals: torch.Tensor,
        surface_sizes: torch.Tensor,
        stl_vertices: torch.Tensor,
        stl_indices: torch.Tensor,
        surface_fields: torch.Tensor | None,
    ) -> dict[str, torch.Tensor]:
        """Preprocess surface mesh data for DoMINO.

        Pipeline (order matters): drop degenerate faces, optionally reject
        points outside the *volume* bounding box, optionally subsample,
        gather kNN neighbor features, then normalize coordinates and scale
        targets.

        Args:
            s_min / s_max: Surface bounding-box extrema (used for
                coordinate normalization).
            c_min / c_max: Volume bounding-box extrema (used for the
                in-bbox rejection below).
            center_of_mass: Area-weighted center of mass of the STL surface.
            surf_grid: Surface grid (unused here beyond the signature —
                NOTE(review): appears to be accepted for interface symmetry).
            surface_coordinates / surface_normals / surface_sizes:
                Per-face centers, normals, and areas.
            stl_vertices / stl_indices: Raw STL geometry (unused in the
                visible body; kept for interface symmetry with
                process_volume).
            surface_fields: Ground-truth surface targets, or None at
                inference time.

        Returns:
            Dict of surface tensors; includes "surface_fields" only when
            ground truth was provided.
        """
        # NOTE(review): nx/ny/nz are unpacked but not used in the visible body.
        nx, ny, nz = self.config.grid_resolution

        return_dict = {}

        ########################################################################
        # Remove any faces with area <= 0 (degenerate faces):
        ########################################################################
        idx = surface_sizes > 0
        surface_sizes = surface_sizes[idx]
        surface_normals = surface_normals[idx]
        surface_coordinates = surface_coordinates[idx]
        if surface_fields is not None:
            surface_fields = surface_fields[idx]

        ########################################################################
        # Reject surface points outside of the Bounding Box
        # NOTE - this is using the VOLUME bounding box (c_min/c_max),
        # not the surface box!
        ########################################################################
        if self.config.sample_in_bbox:
            ids_min = surface_coordinates[:] > c_min
            ids_max = surface_coordinates[:] < c_max

            # A point is kept only if it is inside the box on every axis.
            ids_in_bbox = ids_min & ids_max
            ids_in_bbox = ids_in_bbox.all(dim=-1)

            surface_coordinates = surface_coordinates[ids_in_bbox]
            surface_normals = surface_normals[ids_in_bbox]
            surface_sizes = surface_sizes[ids_in_bbox]
            if surface_fields is not None:
                surface_fields = surface_fields[ids_in_bbox]

        ########################################################################
        # Perform down sampling of the surface fields.
        # Note that we snapshot the full surface coordinates for
        # use in the kNN in the next step (neighbors are looked up in the
        # full point set, not the sampled one).
        ########################################################################

        full_surface_coordinates = surface_coordinates
        full_surface_normals = surface_normals
        full_surface_sizes = surface_sizes

        if self.config.sampling:
            # Perform the down sampling; "area_weighted" biases the draw by
            # face area, otherwise the draw is uniform.
            if self.config.surface_sampling_algorithm == "area_weighted":
                weights = surface_sizes
            else:
                weights = None

            surface_coordinates_sampled, idx_surface = shuffle_array(
                surface_coordinates,
                self.config.surface_points_sample,
                weights=weights,
            )

            if surface_coordinates_sampled.shape[0] < self.config.surface_points_sample:
                raise ValueError(
                    "Surface mesh has fewer points than requested sample size"
                )

            # Select out the sampled points for non-neighbor arrays:
            if surface_fields is not None:
                surface_fields = surface_fields[idx_surface]

            # Subsample the normals and sizes:
            surface_normals = surface_normals[idx_surface]
            surface_sizes = surface_sizes[idx_surface]
            # Update the coordinates to the sampled points:
            surface_coordinates = surface_coordinates_sampled

        ########################################################################
        # Perform a kNN on the surface to find the neighbor information
        ########################################################################
        if self.config.num_surface_neighbors > 1:
            # Perform the kNN: queries are the (possibly sampled) points,
            # candidates are the full pre-sampling point set.
            neighbor_indices, neighbor_distances = knn(
                points=full_surface_coordinates,
                queries=surface_coordinates,
                k=self.config.num_surface_neighbors,
            )
            # Pull out the neighbor elements.
            # Note that `neighbor_indices` is the index into the original,
            # full sized tensors (full_surface_coordinates, etc).
            # The [:, 1:] slice drops the first neighbor, which is the query
            # point itself.
            surface_neighbors = full_surface_coordinates[neighbor_indices][:, 1:]
            surface_neighbors_normals = full_surface_normals[neighbor_indices][:, 1:]
            surface_neighbors_sizes = full_surface_sizes[neighbor_indices][:, 1:]
        else:
            # Degenerate k<=1 case: each point is its own "neighborhood".
            surface_neighbors = surface_coordinates
            surface_neighbors_normals = surface_normals
            surface_neighbors_sizes = surface_sizes

        # Better to normalize everything after the kNN and sampling
        # (so distances during kNN are computed in physical units).
        if self.config.normalize_coordinates:
            surface_coordinates = normalize(surface_coordinates, s_max, s_min)
            surface_neighbors = normalize(surface_neighbors, s_max, s_min)
            center_of_mass = normalize(center_of_mass, s_max, s_min)

        # Positional encoding: displacement from each point to the center of mass.
        pos_normals_com_surface = surface_coordinates - center_of_mass

        ########################################################################
        # Apply scaling to the targets, if desired:
        ########################################################################
        if self.config.scaling_type is not None and surface_fields is not None:
            surface_fields = self.scale_model_targets(
                surface_fields, self.config.surface_factors
            )

        return_dict.update(
            {
                "pos_surface_center_of_mass": pos_normals_com_surface,
                "surface_mesh_centers": surface_coordinates,
                "surface_mesh_neighbors": surface_neighbors,
                "surface_normals": surface_normals,
                "surface_neighbors_normals": surface_neighbors_normals,
                "surface_areas": surface_sizes,
                "surface_neighbors_areas": surface_neighbors_sizes,
            }
        )
        if surface_fields is not None:
            return_dict["surface_fields"] = surface_fields

        return return_dict
525
+
526
+ def process_volume(
527
+ self,
528
+ c_min: torch.Tensor,
529
+ c_max: torch.Tensor,
530
+ volume_coordinates: torch.Tensor,
531
+ volume_grid: torch.Tensor,
532
+ center_of_mass: torch.Tensor,
533
+ stl_vertices: torch.Tensor,
534
+ stl_indices: torch.Tensor,
535
+ volume_fields: torch.Tensor | None,
536
+ ) -> dict[str, torch.Tensor]:
537
+ """
538
+ Preprocess the volume data.
539
+
540
+ First, if configured, we reject points not in the volume bounding box.
541
+
542
+ Next, if sampling is enabled, we sample the volume points and apply that
543
+ sampling to the ground truth too, if it's present.
544
+
545
+ """
546
+ ########################################################################
547
+ # Reject points outside the volumetric BBox
548
+ ########################################################################
549
+ if self.config.sample_in_bbox:
550
+ # Remove points in the volume that are outside
551
+ # of the bbox area.
552
+ min_check = volume_coordinates[:] > c_min
553
+ max_check = volume_coordinates[:] < c_max
554
+
555
+ ids_in_bbox = min_check & max_check
556
+ ids_in_bbox = ids_in_bbox.all(dim=1)
557
+
558
+ volume_coordinates = volume_coordinates[ids_in_bbox]
559
+ if volume_fields is not None:
560
+ volume_fields = volume_fields[ids_in_bbox]
561
+
562
+ ########################################################################
563
+ # Apply sampling to the volume coordinates and fields
564
+ ########################################################################
565
+
566
+ # If the volume data has been sampled from disk, directly, then
567
+ # still apply sampling. We over-pull from disk deliberately.
568
+ if self.config.sampling:
569
+ # Generate a series of idx to sample the volume
570
+ # without replacement
571
+ volume_coordinates_sampled, idx_volume = shuffle_array(
572
+ volume_coordinates, self.config.volume_points_sample
573
+ )
574
+ volume_coordinates_sampled = volume_coordinates[idx_volume]
575
+ # In case too few points are in the sampled data (because the
576
+ # inputs were too few), pad the outputs:
577
+ if volume_coordinates_sampled.shape[0] < self.config.volume_points_sample:
578
+ raise ValueError(
579
+ "Volume mesh has fewer points than requested sample size"
580
+ )
581
+
582
+ # Apply the same sampling to the targets, too:
583
+ if volume_fields is not None:
584
+ volume_fields = volume_fields[idx_volume]
585
+
586
+ volume_coordinates = volume_coordinates_sampled
587
+
588
+ ########################################################################
589
+ # Apply normalization to the coordinates, if desired:
590
+ ########################################################################
591
+ if self.config.normalize_coordinates:
592
+ volume_coordinates = normalize(volume_coordinates, c_max, c_min)
593
+ grid = normalize(volume_grid, c_max, c_min)
594
+ normed_vertices = normalize(stl_vertices, c_max, c_min)
595
+ center_of_mass = normalize(center_of_mass, c_max, c_min)
596
+ else:
597
+ grid = volume_grid
598
+ normed_vertices = stl_vertices
599
+ center_of_mass = center_of_mass
600
+
601
+ ########################################################################
602
+ # Apply scaling to the targets, if desired:
603
+ ########################################################################
604
+ if self.config.scaling_type is not None and volume_fields is not None:
605
+ volume_fields = self.scale_model_targets(
606
+ volume_fields, self.config.volume_factors
607
+ )
608
+
609
+ ########################################################################
610
+ # Compute Signed Distance Function for volumetric quantities
611
+ # Note - the SDF happens here, after volume data processing finishes,
612
+ # because we need to use the (maybe) normalized volume coordinates and grid
613
+ ########################################################################
614
+
615
+ # SDF calculation on the volume grid using WARP
616
+ sdf_grid, _ = signed_distance_field(
617
+ normed_vertices,
618
+ stl_indices,
619
+ grid,
620
+ use_sign_winding_number=True,
621
+ )
622
+
623
+ # Get the SDF of all the selected volume coordinates,
624
+ # And keep the closest point to each one.
625
+ sdf_nodes, sdf_node_closest_point = signed_distance_field(
626
+ normed_vertices,
627
+ stl_indices,
628
+ volume_coordinates,
629
+ use_sign_winding_number=True,
630
+ )
631
+ sdf_nodes = sdf_nodes.reshape((-1, 1))
632
+
633
+ # Use the closest point from the mesh to compute the volume encodings:
634
+ pos_normals_closest_vol, pos_normals_com_vol = self.calculate_volume_encoding(
635
+ volume_coordinates, sdf_node_closest_point, center_of_mass
636
+ )
637
+
638
+ return_dict = {
639
+ "volume_mesh_centers": volume_coordinates,
640
+ "sdf_nodes": sdf_nodes,
641
+ "grid": grid,
642
+ "sdf_grid": sdf_grid,
643
+ "pos_volume_closest": pos_normals_closest_vol,
644
+ "pos_volume_center_of_mass": pos_normals_com_vol,
645
+ }
646
+
647
+ if volume_fields is not None:
648
+ return_dict["volume_fields"] = volume_fields
649
+
650
+ return return_dict
651
+
652
+ def calculate_volume_encoding(
653
+ self,
654
+ volume_coordinates: torch.Tensor,
655
+ sdf_node_closest_point: torch.Tensor,
656
+ center_of_mass: torch.Tensor,
657
+ ):
658
+ pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point
659
+ pos_normals_com_vol = volume_coordinates - center_of_mass
660
+
661
+ return pos_normals_closest_vol, pos_normals_com_vol
662
+
663
    @torch.no_grad()
    def process_data(self, data_dict):
        """Preprocess one raw example into DoMINO model inputs.

        Validates required keys, computes the surface-grid SDF, center of
        mass, and geometry downsampling, then dispatches to
        ``process_surface`` / ``process_volume`` according to
        ``self.model_type``.  When domain parallelism is enabled
        (``shard_grid`` / ``shard_points``) the input ShardTensors are
        gathered to full tensors first and the results are re-scattered at
        the end.

        Args:
            data_dict: Mapping of tensor names to (Shard)Tensors; must
                contain at least the keys listed in ``required_keys``.

        Returns:
            Dict of preprocessed tensors (ShardTensors when sharding is on).

        Raises:
            ValueError: if any required key is missing.
        """
        # Validate that all required keys are present in data_dict
        required_keys = [
            "global_params_values",
            "global_params_reference",
            "stl_coordinates",
            "stl_faces",
            "stl_centers",
            "stl_areas",
        ]
        missing_keys = [key for key in required_keys if key not in data_dict]
        if missing_keys:
            raise ValueError(
                f"Missing required keys in data_dict: {missing_keys}. "
                f"Required keys are: {required_keys}"
            )

        # Start building the preprocessed return dict:
        return_dict = {
            "global_params_values": data_dict["global_params_values"],
            "global_params_reference": data_dict["global_params_reference"],
        }

        # DoMINO's sharded datapipe can be tricky - output shapes are not always
        # so simple to calculate, since much of the datapipe is dynamic.
        # The dataset will read in sharded data, to minimize IO.
        # We collect it all locally, here, and then scatter
        # appropriately for the outputs.

        if self.config.shard_grid or self.config.shard_points:
            # Get the device mesh (also reused for the re-scatter at the end):
            mesh = data_dict["stl_coordinates"]._spec.mesh
            local_data_dict = {}
            for key, value in data_dict.items():
                # Gather every ShardTensor to a full local tensor.
                local_data_dict[key] = value.full_tensor()

            data_dict = local_data_dict

        ########################################################################
        # Process the core STL information
        ########################################################################

        # This function gets information about the surface scale,
        # and decides what the surface grid will be:

        s_min, s_max, surf_grid = self.compute_stl_scaling_and_surface_grids()

        # We always need to calculate the SDF on the surface grid:
        # This is for the SDF later:
        if self.config.normalize_coordinates:
            normed_vertices = normalize(data_dict["stl_coordinates"], s_max, s_min)
            surf_grid = normalize(surf_grid, s_max, s_min)
        else:
            normed_vertices = data_dict["stl_coordinates"]

        # For SDF calculations, make sure the mesh_indices_flattened is an integer array:
        mesh_indices_flattened = data_dict["stl_faces"].to(torch.int32)

        # Compute signed distance function for the surface grid:
        sdf_surf_grid, _ = signed_distance_field(
            mesh_vertices=normed_vertices,
            mesh_indices=mesh_indices_flattened,
            input_points=surf_grid,
            use_sign_winding_number=True,
        )
        return_dict["sdf_surf_grid"] = sdf_surf_grid
        return_dict["surf_grid"] = surf_grid

        # Store this only if normalization is active:
        if self.config.normalize_coordinates:
            return_dict["surface_min_max"] = torch.stack([s_min, s_max])

        # This is a center of mass computation for the stl surface,
        # using the size of each mesh point as weight.
        center_of_mass = calculate_center_of_mass(
            data_dict["stl_centers"], data_dict["stl_areas"]
        )

        # This will apply downsampling if needed to the geometry coordinates
        geom_centers = self.downsample_geometry(
            stl_vertices=data_dict["stl_coordinates"],
        )
        return_dict["geometry_coordinates"] = geom_centers

        ########################################################################
        # Determine the volumetric bounds of the data:
        ########################################################################
        # Compute the min/max for volume and the unnormalized grid:
        c_min, c_max, volume_grid = self.compute_volume_scaling_and_grids()

        ########################################################################
        # Process the surface data
        ########################################################################
        if self.model_type == "surface" or self.model_type == "combined":
            # Ground truth may be absent at inference time:
            surface_fields_raw = (
                data_dict["surface_fields"] if "surface_fields" in data_dict else None
            )
            surface_dict = self.process_surface(
                s_min,
                s_max,
                c_min,
                c_max,
                center_of_mass=center_of_mass,
                surf_grid=surf_grid,
                surface_coordinates=data_dict["surface_mesh_centers"],
                surface_normals=data_dict["surface_normals"],
                surface_sizes=data_dict["surface_areas"],
                stl_vertices=data_dict["stl_coordinates"],
                stl_indices=mesh_indices_flattened,
                surface_fields=surface_fields_raw,
            )

            return_dict.update(surface_dict)

        ########################################################################
        # Process the volume data
        ########################################################################
        # For volume data, we store this only if normalizing coordinates:
        if self.model_type == "volume" or self.model_type == "combined":
            if self.config.normalize_coordinates:
                return_dict["volume_min_max"] = torch.stack([c_min, c_max])

        if self.model_type == "volume" or self.model_type == "combined":
            # Ground truth may be absent at inference time:
            volume_fields_raw = (
                data_dict["volume_fields"] if "volume_fields" in data_dict else None
            )
            volume_dict = self.process_volume(
                c_min,
                c_max,
                volume_coordinates=data_dict["volume_mesh_centers"],
                volume_grid=volume_grid,
                center_of_mass=center_of_mass,
                stl_vertices=data_dict["stl_coordinates"],
                stl_indices=mesh_indices_flattened,
                volume_fields=volume_fields_raw,
            )

            return_dict.update(volume_dict)

        # For domain parallelism, shard everything appropriately:
        if self.config.shard_grid or self.config.shard_points:
            # Mesh was defined above!
            output_dict = {}

            # For scattering, we need to know the _global_ index of rank
            # 0 on this mesh.
            # NOTE(review): `dist` here must resolve to a module-level name
            # (e.g. torch.distributed imported as `dist`) — the local `dist`
            # in __init__ is not in scope.  Confirm at the file's imports.
            global_index = dist.get_global_rank(mesh.get_group(), 0)

            for key, value in return_dict.items():
                # Grid-like tensors follow shard_grid; everything else
                # follows shard_points.
                grid_placements = (
                    [
                        Shard(0),
                    ]
                    if self.config.shard_grid
                    else [
                        Replicate(),
                    ]
                )
                point_placements = (
                    [
                        Shard(0),
                    ]
                    if self.config.shard_points
                    else [
                        Replicate(),
                    ]
                )
                if key == "volume_min_max":
                    # Min/max summaries are small; replicate them everywhere.
                    output_dict[key] = ShardTensor.from_local(
                        value,
                        mesh,
                        [
                            Replicate(),
                        ],
                    )
                elif key == "surface_min_max":
                    output_dict[key] = ShardTensor.from_local(
                        value,
                        mesh,
                        [
                            Replicate(),
                        ],
                    )
                elif not isinstance(value, ShardTensor):
                    if "grid" in key:
                        output_dict[key] = scatter_tensor(
                            value.contiguous(),
                            global_index,
                            mesh,
                            grid_placements,
                            global_shape=value.shape,
                            dtype=value.dtype,
                        )
                    else:
                        output_dict[key] = scatter_tensor(
                            value.contiguous(),
                            global_index,
                            mesh,
                            point_placements,
                            global_shape=value.shape,
                            dtype=value.dtype,
                        )
                else:
                    # Already a ShardTensor; pass through unchanged.
                    output_dict[key] = value

            return_dict = output_dict

        return return_dict
872
+
873
+ def scale_model_targets(
874
+ self, fields: torch.Tensor, factors: torch.Tensor
875
+ ) -> torch.Tensor:
876
+ """
877
+ Scale the model targets based on the configured scaling factors.
878
+ """
879
+ if self.config.scaling_type == "mean_std_scaling":
880
+ field_mean = factors[0]
881
+ field_std = factors[1]
882
+ return standardize(fields, field_mean, field_std)
883
+ elif self.config.scaling_type == "min_max_scaling":
884
+ field_min = factors[1]
885
+ field_max = factors[0]
886
+ return normalize(fields, field_max, field_min)
887
+
888
    def unscale_model_outputs(
        self,
        volume_fields: torch.Tensor | None = None,
        surface_fields: torch.Tensor | None = None,
    ):
        """
        Unscale the model outputs based on the configured scaling factors.

        The unscaling is included here to make it a consistent interface
        regardless of the scaling factors and type used.  ShardTensor inputs
        are unwrapped to local tensors, unscaled, and re-wrapped with their
        original spec so sharded and plain tensors take the same path.

        Args:
            volume_fields: Model volume outputs (Tensor or ShardTensor), or None.
            surface_fields: Model surface outputs (Tensor or ShardTensor), or None.

        Returns:
            Tuple of (volume_fields, surface_fields), each unscaled or None.
        """

        # This is a step to make sure we can apply to sharded outputs:
        # remember each ShardTensor's spec so it can be restored afterwards.
        if volume_fields is not None and isinstance(volume_fields, ShardTensor):
            volume_spec = volume_fields._spec
            volume_fields = ShardTensor.to_local(volume_fields)
        else:
            volume_spec = None

        if surface_fields is not None and isinstance(surface_fields, ShardTensor):
            surface_spec = surface_fields._spec
            surface_fields = ShardTensor.to_local(surface_fields)
        else:
            surface_spec = None

        # Inverse of scale_model_targets: factors are (mean, std) for
        # mean_std_scaling and (max, min) for min_max_scaling.
        if volume_fields is not None:
            if self.config.scaling_type == "mean_std_scaling":
                vol_mean = self.config.volume_factors[0]
                vol_std = self.config.volume_factors[1]
                volume_fields = unstandardize(volume_fields, vol_mean, vol_std)
            elif self.config.scaling_type == "min_max_scaling":
                vol_min = self.config.volume_factors[1]
                vol_max = self.config.volume_factors[0]
                volume_fields = unnormalize(volume_fields, vol_max, vol_min)
        if surface_fields is not None:
            if self.config.scaling_type == "mean_std_scaling":
                surf_mean = self.config.surface_factors[0]
                surf_std = self.config.surface_factors[1]
                surface_fields = unstandardize(surface_fields, surf_mean, surf_std)
            elif self.config.scaling_type == "min_max_scaling":
                surf_min = self.config.surface_factors[1]
                surf_max = self.config.surface_factors[0]
                surface_fields = unnormalize(surface_fields, surf_max, surf_min)

        # Re-wrap as ShardTensors with the original mesh/placements/shapes:
        if volume_spec is not None:
            volume_fields = ShardTensor.from_local(
                volume_fields,
                device_mesh=volume_spec.mesh,
                placements=volume_spec.placements,
                sharding_shapes=volume_spec.sharding_shapes(),
            )
        if surface_spec is not None:
            surface_fields = ShardTensor.from_local(
                surface_fields,
                device_mesh=surface_spec.mesh,
                placements=surface_spec.placements,
                sharding_shapes=surface_spec.sharding_shapes(),
            )

        return volume_fields, surface_fields
949
+
950
    def set_dataset(self, dataset: Iterable) -> None:
        """
        Pass a dataset to the datapipe to enable iterating over both in one pass.
        """
        self.dataset = dataset

        if self.config.volume_sample_from_disk:
            # Deliberately over-pull from disk (100x the requested sampling
            # size) so the later in-memory sampling in process_volume still
            # has enough points to draw from.
            self.dataset.set_volume_sampling_size(
                100 * self.config.volume_points_sample
            )
961
+
962
+ def __len__(self):
963
+ if self.dataset is not None:
964
+ return len(self.dataset)
965
+ else:
966
+ return 0
967
+
968
+ def __getitem__(self, idx):
969
+ """
970
+ Function for fetching and processing a single file's data.
971
+
972
+ Domino, in general, expects one example per file and the files
973
+ are relatively large due to the mesh size.
974
+
975
+ Requires the user to have set a dataset via `set_dataset`.
976
+ """
977
+ if self.dataset is None:
978
+ raise ValueError("Dataset is not present")
979
+
980
+ # Get the data from the dataset.
981
+ # Under the hood, this may be fetching preloaded data.
982
+ data_dict = self.dataset[idx]
983
+
984
+ return self.__call__(data_dict)
985
+
986
+ def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
987
+ """
988
+ Process the incoming data dictionary.
989
+ - Processes the data
990
+ - moves it to GPU
991
+ - adds a batch dimension
992
+
993
+ Args:
994
+ data_dict: Dictionary containing the data to process as torch.Tensors.
995
+
996
+ Returns:
997
+ Dictionary containing the processed data as torch.Tensors.
998
+
999
+ """
1000
+ data_dict = self.process_data(data_dict)
1001
+
1002
+ # If the data is not on the target device, put it there:
1003
+ for key, value in data_dict.items():
1004
+ if value.device != self.output_device:
1005
+ data_dict[key] = value.to(self.output_device)
1006
+
1007
+ # Add a batch dimension to the data_dict
1008
+ data_dict = {k: v.unsqueeze(0) for k, v in data_dict.items()}
1009
+
1010
+ return data_dict
1011
+
1012
+ def __iter__(self):
1013
+ if self.dataset is None:
1014
+ raise ValueError(
1015
+ "Dataset is not present, can not use the datapipe as an iterator."
1016
+ )
1017
+
1018
+ for i, batch in enumerate(self.dataset):
1019
+ yield self.__call__(batch)
1020
+
1021
+
1022
def compute_scaling_factors(
    cfg: DictConfig,
    input_path: str,
    target_keys: list[str],
    max_samples=20,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Using the dataset at the path, compute the mean, std, min, and max of the target keys.

    Args:
        cfg: Hydra configuration object (unused in the visible implementation;
            kept for interface compatibility).
        input_path: Path to the dataset to load.
        target_keys: List of keys to compute the mean, std, min, and max of.
        max_samples: Maximum number of samples used to estimate the statistics.

    Returns:
        Tuple of (mean, std, min, max) tensors for the target keys.
    """
    # Prefer the GPU when available — statistics are computed over large fields.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    dataset = CAEDataset(
        data_dir=input_path,
        keys_to_read=target_keys,
        keys_to_read_if_available={},
        output_device=device,
    )

    # compute_mean_std_min_max already returns the (mean, std, min, max) tuple.
    return compute_mean_std_min_max(
        dataset,
        field_keys=target_keys,
        max_samples=max_samples,
    )
1054
+
1055
+
1056
+ class CachedDoMINODataset(Dataset):
1057
+ """
1058
+ Dataset for reading cached DoMINO data files, with optional resampling.
1059
+ Acts as a drop-in replacement for DoMINODataPipe.
1060
+ """
1061
+
1062
+ # @nvtx_annotate(message="CachedDoMINODataset __init__")
1063
    def __init__(
        self,
        data_path: Union[str, Path],
        phase: Literal["train", "val", "test"] = "train",
        sampling: bool = False,
        volume_points_sample: Optional[int] = None,
        surface_points_sample: Optional[int] = None,
        geom_points_sample: Optional[int] = None,
        model_type=None,  # Model_type: surface, volume or combined
        deterministic_seed=False,
        surface_sampling_algorithm="area_weighted",
    ):
        """Index a directory of cached (preprocessed) DoMINO .npy files.

        Args:
            data_path: Directory containing the cached files.
            phase: Dataset split label ("train", "val", or "test").
            sampling: Whether __getitem__ should resample points.
            volume_points_sample / surface_points_sample / geom_points_sample:
                Per-category sample sizes used when ``sampling`` is True.
            model_type: surface, volume, or combined.
            deterministic_seed: If True, seed numpy's RNG for reproducibility.
            surface_sampling_algorithm: "area_weighted" or "random".

        Raises:
            AssertionError: if the path is missing, not a directory, or
                contains no cached files.
        """
        super().__init__()

        self.model_type = model_type
        # NOTE(review): this seeds numpy's *global* RNG, which affects every
        # other numpy consumer in the process — confirm that is intended.
        if deterministic_seed:
            np.random.seed(42)

        # Normalize to an expanded pathlib.Path before validating it.
        if isinstance(data_path, str):
            data_path = Path(data_path)
        self.data_path = data_path.expanduser()

        if not self.data_path.exists():
            raise AssertionError(f"Path {self.data_path} does not exist")
        if not self.data_path.is_dir():
            raise AssertionError(f"Path {self.data_path} is not a directory")

        self.deterministic_seed = deterministic_seed
        self.sampling = sampling
        self.volume_points = volume_points_sample
        self.surface_points = surface_points_sample
        self.geom_points = geom_points_sample
        self.surface_sampling_algorithm = surface_sampling_algorithm

        # One cached example per file; directories are excluded.
        self.filenames = get_filenames(self.data_path, exclude_dirs=True)

        total_files = len(self.filenames)

        self.phase = phase
        # Shuffled index table mapping dataset position -> filename.
        self.indices = np.array(range(total_files))

        np.random.shuffle(self.indices)

        if not self.filenames:
            raise AssertionError(f"No cached files found in {self.data_path}")
1109
+ def __len__(self):
1110
+ return len(self.indices)
1111
+
1112
+ # @nvtx_annotate(message="CachedDoMINODataset __getitem__")
1113
+ def __getitem__(self, idx):
1114
+ if self.deterministic_seed:
1115
+ np.random.seed(idx)
1116
+ nvtx.range_push("Load cached file")
1117
+
1118
+ index = self.indices[idx]
1119
+ cfd_filename = self.filenames[index]
1120
+
1121
+ filepath = self.data_path / cfd_filename
1122
+ result = np.load(filepath, allow_pickle=True).item()
1123
+ result = {
1124
+ k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
1125
+ for k, v in result.items()
1126
+ }
1127
+
1128
+ nvtx.range_pop()
1129
+ if not self.sampling:
1130
+ return result
1131
+
1132
+ nvtx.range_push("Sample points")
1133
+
1134
+ # Sample volume points if present
1135
+ if "volume_mesh_centers" in result and self.volume_points:
1136
+ coords_sampled, idx_volume = shuffle_array(
1137
+ result["volume_mesh_centers"], self.volume_points
1138
+ )
1139
+ if coords_sampled.shape[0] < self.volume_points:
1140
+ coords_sampled = pad(
1141
+ coords_sampled, self.volume_points, pad_value=-10.0
1142
+ )
1143
+
1144
+ result["volume_mesh_centers"] = coords_sampled
1145
+ for key in [
1146
+ "volume_fields",
1147
+ "pos_volume_closest",
1148
+ "pos_volume_center_of_mass",
1149
+ "sdf_nodes",
1150
+ ]:
1151
+ if key in result:
1152
+ result[key] = result[key][idx_volume]
1153
+
1154
+ # Sample surface points if present
1155
+ if "surface_mesh_centers" in result and self.surface_points:
1156
+ if self.surface_sampling_algorithm == "area_weighted":
1157
+ coords_sampled, idx_surface = shuffle_array(
1158
+ points=result["surface_mesh_centers"],
1159
+ n_points=self.surface_points,
1160
+ weights=result["surface_areas"],
1161
+ )
1162
+ else:
1163
+ coords_sampled, idx_surface = shuffle_array(
1164
+ result["surface_mesh_centers"], self.surface_points
1165
+ )
1166
+
1167
+ if coords_sampled.shape[0] < self.surface_points:
1168
+ coords_sampled = pad(
1169
+ coords_sampled, self.surface_points, pad_value=-10.0
1170
+ )
1171
+
1172
+ ii = result["neighbor_indices"]
1173
+ result["surface_mesh_neighbors"] = result["surface_mesh_centers"][ii]
1174
+ result["surface_neighbors_normals"] = result["surface_normals"][ii]
1175
+ result["surface_neighbors_areas"] = result["surface_areas"][ii]
1176
+
1177
+ result["surface_mesh_centers"] = coords_sampled
1178
+
1179
+ for key in [
1180
+ "surface_fields",
1181
+ "surface_areas",
1182
+ "surface_normals",
1183
+ "pos_surface_center_of_mass",
1184
+ "surface_mesh_neighbors",
1185
+ "surface_neighbors_normals",
1186
+ "surface_neighbors_areas",
1187
+ ]:
1188
+ if key in result:
1189
+ result[key] = result[key][idx_surface]
1190
+
1191
+ del result["neighbor_indices"]
1192
+
1193
+ # Sample geometry points if present
1194
+ if "geometry_coordinates" in result and self.geom_points:
1195
+ coords_sampled, _ = shuffle_array(
1196
+ result["geometry_coordinates"], self.geom_points
1197
+ )
1198
+ if coords_sampled.shape[0] < self.geom_points:
1199
+ coords_sampled = pad(coords_sampled, self.geom_points, pad_value=-100.0)
1200
+ result["geometry_coordinates"] = coords_sampled
1201
+
1202
+ nvtx.range_pop()
1203
+ return result
1204
+
1205
+
1206
def create_domino_dataset(
    cfg: DictConfig,
    phase: Literal["train", "val", "test"],
    keys_to_read: list[str],
    keys_to_read_if_available: dict[str, torch.Tensor],
    vol_factors: list[float],
    surf_factors: list[float],
    normalize_coordinates: bool = True,
    sample_in_bbox: bool = True,
    sampling: bool = True,
    device_mesh: torch.distributed.DeviceMesh | None = None,
    placements: dict[str, torch.distributed.tensor.Placement] | None = None,
):
    """Build the DoMINO dataset/datapipe for one split from the Hydra config.

    When ``cfg.data_processor.use_cache`` is set, returns a
    ``CachedDoMINODataset`` over preprocessed files. Otherwise builds a raw
    ``CAEDataset`` (optionally GPU-resident and domain-sharded) and wraps it
    in a ``DoMINODataPipe`` that preprocesses samples on the fly.

    Args:
        cfg: Hydra configuration object containing all parameters.
        phase: Split to build; selects the input path and dataloader config.
        keys_to_read: Keys that must be present in each raw sample.
        keys_to_read_if_available: Optional keys mapped to fallback tensors
            used when a sample lacks them.
        vol_factors: Volume-field normalization factors.
        surf_factors: Surface-field normalization factors.
        normalize_coordinates: Whether the datapipe normalizes coordinates.
        sample_in_bbox: Whether volume sampling is restricted to the bounding box.
        sampling: Whether points are subsampled.
        device_mesh: Optional device mesh for domain-parallel sharded loading.
        placements: Optional per-key tensor placements for sharding.

    Raises:
        ValueError: If ``phase`` is not one of "train", "val", "test".
    """
    model_type = cfg.model.model_type
    # Resolve split-specific input path and dataloader settings.
    if phase == "train":
        input_path = cfg.data.input_dir
        dataloader_cfg = cfg.train.dataloader
    elif phase == "val":
        input_path = cfg.data.input_dir_val
        dataloader_cfg = cfg.val.dataloader
    elif phase == "test":
        input_path = cfg.eval.test_path
        dataloader_cfg = None
    else:
        raise ValueError(f"Invalid phase {phase}")

    if cfg.data_processor.use_cache:
        # Cached path: files were already preprocessed; only (re)sampling
        # happens at load time.
        return CachedDoMINODataset(
            input_path,
            phase=phase,
            sampling=sampling,
            volume_points_sample=cfg.model.volume_points_sample,
            surface_points_sample=cfg.model.surface_points_sample,
            geom_points_sample=cfg.model.geom_points_sample,
            model_type=cfg.model.model_type,
            surface_sampling_algorithm=cfg.model.surface_sampling_algorithm,
        )
    else:
        # The dataset path works in two pieces:
        # There is a core "dataset" which is loading data and moving to GPU
        # And there is the preprocess step, here.

        # Optionally, and for backwards compatibility, the preprocess
        # object can accept a dataset which will enable it as an iterator.
        # The iteration function will loop over the dataset, preprocess the
        # output, and return it.

        # Collect optional datapipe kwargs that only exist in newer configs.
        overrides = {}
        if hasattr(cfg.data, "gpu_preprocessing"):
            overrides["gpu_preprocessing"] = cfg.data.gpu_preprocessing

        if hasattr(cfg.data, "gpu_output"):
            overrides["gpu_output"] = cfg.data.gpu_output

        dm = DistributedManager()

        # When preprocessing on GPU, the dataset hands tensors off on the
        # default CUDA stream of this rank's device.
        if cfg.data.gpu_preprocessing:
            device = dm.device
            consumer_stream = torch.cuda.default_stream()
        else:
            device = torch.device("cpu")
            consumer_stream = None

        if dataloader_cfg is not None:
            preload_depth = dataloader_cfg.preload_depth
            pin_memory = dataloader_cfg.pin_memory
        else:
            # Test phase has no dataloader config; use conservative defaults.
            preload_depth = 1
            pin_memory = False

        dataset = CAEDataset(
            data_dir=input_path,
            keys_to_read=keys_to_read,
            keys_to_read_if_available=keys_to_read_if_available,
            output_device=device,
            preload_depth=preload_depth,
            pin_memory=pin_memory,
            device_mesh=device_mesh,
            placements=placements,
            consumer_stream=consumer_stream,
        )

        # Domain parallelism configuration:
        # (By default, the dataset will shard as aggressively as possible,
        # to improve IO speed and prevent bottlenecks - the datapipe
        # has to reshard to the final shape.)

        # NOTE: we can always capture the mesh and placements from the dataset
        # outputs, so no need to pass them here.
        if cfg.get("domain_parallelism", {}).get("domain_size", 1) > 1:
            shard_grid = cfg.get("domain_parallelism", {}).get("shard_grid", False)
            shard_points = cfg.get("domain_parallelism", {}).get("shard_points", False)
            overrides["shard_grid"] = shard_grid
            overrides["shard_points"] = shard_points

        datapipe = DoMINODataPipe(
            input_path,
            phase=phase,
            grid_resolution=cfg.model.interp_res,
            normalize_coordinates=normalize_coordinates,
            sampling=sampling,
            sample_in_bbox=sample_in_bbox,
            volume_points_sample=cfg.model.volume_points_sample,
            surface_points_sample=cfg.model.surface_points_sample,
            geom_points_sample=cfg.model.geom_points_sample,
            volume_factors=vol_factors,
            surface_factors=surf_factors,
            scaling_type=cfg.model.normalization,
            model_type=model_type,
            bounding_box_dims=cfg.data.bounding_box,
            bounding_box_dims_surf=cfg.data.bounding_box_surface,
            volume_sample_from_disk=cfg.data.volume_sample_from_disk,
            num_surface_neighbors=cfg.model.num_neighbors_surface,
            surface_sampling_algorithm=cfg.model.surface_sampling_algorithm,
            **overrides,
        )

        # Attach the loading dataset so the datapipe can act as an iterator
        # over preprocessed samples.
        datapipe.set_dataset(dataset)

        return datapipe
1326
+
1327
+
1328
+ if __name__ == "__main__":
1329
+ fm_data = DoMINODataPipe(
1330
+ data_path="/code/processed_data/new_models_1/",
1331
+ phase="train",
1332
+ sampling=False,
1333
+ sample_in_bbox=False,
1334
+ )
physics_mcp/source/physicsnemo/datapipes/cae/mesh_datapipe.py ADDED
@@ -0,0 +1,490 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import numpy as np
19
+ import torch
20
+ import vtk
21
+
22
+ try:
23
+ import nvidia.dali as dali
24
+ import nvidia.dali.plugin.pytorch as dali_pth
25
+ except ImportError:
26
+ raise ImportError(
27
+ "DALI dataset requires NVIDIA DALI package to be installed. "
28
+ + "The package can be installed at:\n"
29
+ + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
30
+ )
31
+
32
+ from dataclasses import dataclass
33
+ from pathlib import Path
34
+ from typing import Iterable, List, Tuple, Union
35
+
36
+ from torch import Tensor
37
+
38
+ from physicsnemo.datapipes.datapipe import Datapipe
39
+ from physicsnemo.datapipes.meta import DatapipeMetaData
40
+
41
+ from .readers import read_cgns, read_vtp, read_vtu
42
+
43
+
44
@dataclass
class MetaData(DatapipeMetaData):
    """Datapipe metadata: names the pipeline and declares which framework
    optimizations and parallelism modes it supports."""

    name: str = "MeshDatapipe"
    # Optimization
    auto_device: bool = True  # pipeline may be moved to the best device
    cuda_graphs: bool = True  # compatible with CUDA graph capture
    # Parallel
    ddp_sharding: bool = True  # supports sharding across DDP ranks
52
+
53
+
54
+ class MeshDatapipe(Datapipe):
55
+ """DALI data pipeline for mesh data
56
+
57
+ Parameters
58
+ ----------
59
+ data_dir : str
60
+ Directory where ERA5 data is stored
61
+ variables : List[str, None]
62
+ Ordered list of variables to be loaded from the files
63
+ num_variables : int
64
+ Number of variables to be loaded from the files
65
+ file_format : str, optional
66
+ File format of the data, by default "vtp"
67
+ Supported formats: "vtp", "vtu", "cgns"
68
+ stats_dir : Union[str, None], optional
69
+ Directory where statistics are stored, by default None
70
+ If provided, the statistics are used to normalize the attributes
71
+ batch_size : int, optional
72
+ Batch size, by default 1
73
+ num_steps : int, optional
74
+ Number of timesteps are included in the output variables, by default 1
75
+ shuffle : bool, optional
76
+ Shuffle dataset, by default True
77
+ num_workers : int, optional
78
+ Number of workers, by default 1
79
+ device: Union[str, torch.device], optional
80
+ Device for DALI pipeline to run on, by default cuda
81
+ process_rank : int, optional
82
+ Rank ID of local process, by default 0
83
+ world_size : int, optional
84
+ Number of training processes, by default 1
85
+ cache_data : False, optional
86
+ Whether to cache the data in memory for faster access in subsequent epochs, by default False
87
+ Parallel: True, optional
88
+ Setting parallel=True for an external_source node indicates to the pipeline to run the source in Python worker processes started by DALI.
89
+ """
90
+
91
+ def __init__(
92
+ self,
93
+ data_dir: str,
94
+ variables: List[str],
95
+ num_variables: int,
96
+ file_format: str = "vtp",
97
+ stats_dir: Union[str, None] = None,
98
+ batch_size: int = 1,
99
+ num_samples: int = 1,
100
+ shuffle: bool = True,
101
+ num_workers: int = 1,
102
+ device: Union[str, torch.device] = "cuda",
103
+ process_rank: int = 0,
104
+ world_size: int = 1,
105
+ cache_data: bool = False,
106
+ parallel: bool = True,
107
+ ):
108
+ super().__init__(meta=MetaData())
109
+ self.file_format = file_format
110
+ self.variables = variables
111
+ self.num_variables = num_variables
112
+ self.batch_size = batch_size
113
+ self.num_workers = num_workers
114
+ self.shuffle = shuffle
115
+ self.data_dir = Path(data_dir)
116
+ self.stats_dir = Path(stats_dir) if stats_dir is not None else None
117
+ self.num_samples = num_samples
118
+ self.process_rank = process_rank
119
+ self.world_size = world_size
120
+ self.cache_data = cache_data
121
+ self.parallel = parallel
122
+
123
+ # if self.batch_size > 1:
124
+ # raise NotImplementedError("Batch size greater than 1 is not supported yet")
125
+
126
+ # Set up device, needed for pipeline
127
+ if isinstance(device, str):
128
+ device = torch.device(device)
129
+ # Need a index id if cuda
130
+ if device.type == "cuda" and device.index is None:
131
+ device = torch.device("cuda:0")
132
+ self.device = device
133
+
134
+ # check root directory exists
135
+ if not self.data_dir.is_dir():
136
+ raise IOError(f"Error, data directory {self.data_dir} does not exist")
137
+
138
+ self.parse_dataset_files()
139
+ self.load_statistics()
140
+
141
+ self.pipe = self._create_pipeline()
142
+
143
+ def parse_dataset_files(self) -> None:
144
+ """Parses the data directory for valid files and determines training samples
145
+
146
+ Raises
147
+ ------
148
+ ValueError
149
+ In channels specified or number of samples per year is not valid
150
+ """
151
+ # get all input data files
152
+ match self.file_format:
153
+ case "vtp":
154
+ pattern = "*.vtp"
155
+ case "vtu":
156
+ pattern = "*.vtu"
157
+ case "cgns":
158
+ pattern = "*.cgns"
159
+ case _:
160
+ raise NotImplementedError(
161
+ f"Data type {self.file_format} is not supported yet"
162
+ )
163
+
164
+ self.data_paths = sorted(str(path) for path in self.data_dir.glob(pattern))
165
+
166
+ for data_path in self.data_paths:
167
+ self.logger.info(f"File found: {data_path}")
168
+ self.total_samples = len(self.data_paths)
169
+
170
+ if self.num_samples > self.total_samples:
171
+ raise ValueError(
172
+ "Number of requested samples is greater than the total number of available samples!"
173
+ )
174
+ self.logger.info(
175
+ f"Total number of samples: {self.total_samples}, number of requested samples: {self.num_samples}"
176
+ )
177
+
178
+ def load_statistics(
179
+ self,
180
+ ) -> None: # TODO generalize and combine with climate/era5_hdf5 datapipes
181
+ """Loads statistics from pre-computed numpy files
182
+
183
+ The statistic files should be of name global_means.npy and global_std.npy with
184
+ a shape of [1, C] located in the stat_dir.
185
+
186
+ Raises
187
+ ------
188
+ IOError
189
+ If mean or std numpy files are not found
190
+ AssertionError
191
+ If loaded numpy arrays are not of correct size
192
+ """
193
+ # If no stats dir we just skip loading the stats
194
+ if self.stats_dir is None:
195
+ self.mu = None
196
+ self.std = None
197
+ return
198
+ # load normalisation values
199
+ mean_stat_file = self.stats_dir / Path("global_means.npy")
200
+ std_stat_file = self.stats_dir / Path("global_stds.npy")
201
+
202
+ if not mean_stat_file.exists():
203
+ raise IOError(f"Mean statistics file {mean_stat_file} not found")
204
+ if not std_stat_file.exists():
205
+ raise IOError(f"Std statistics file {std_stat_file} not found")
206
+
207
+ # has shape [1, C]
208
+ self.mu = np.load(str(mean_stat_file))[:, 0 : self.num_variables]
209
+ # has shape [1, C]
210
+ self.sd = np.load(str(std_stat_file))[:, 0 : self.num_variables]
211
+
212
+ if not self.mu.shape == self.sd.shape == (1, self.num_variables):
213
+ raise AssertionError("Error, normalisation arrays have wrong shape")
214
+
215
+ def _create_pipeline(self) -> dali.Pipeline:
216
+ """Create DALI pipeline
217
+
218
+ Returns
219
+ -------
220
+ dali.Pipeline
221
+ Mesh DALI pipeline
222
+ """
223
+ pipe = dali.Pipeline(
224
+ batch_size=self.batch_size,
225
+ num_threads=2,
226
+ prefetch_queue_depth=2,
227
+ py_num_workers=self.num_workers,
228
+ device_id=self.device.index,
229
+ py_start_method="spawn",
230
+ )
231
+
232
+ with pipe:
233
+ source = MeshDaliExternalSource(
234
+ data_paths=self.data_paths,
235
+ file_format=self.file_format,
236
+ variables=self.variables,
237
+ num_samples=self.num_samples,
238
+ batch_size=self.batch_size,
239
+ shuffle=self.shuffle,
240
+ process_rank=self.process_rank,
241
+ world_size=self.world_size,
242
+ cache_data=self.cache_data,
243
+ )
244
+ # Update length of dataset
245
+ self.length = len(source) // self.batch_size
246
+ # Read current batch.
247
+ vertices, attributes, edges = dali.fn.external_source(
248
+ source,
249
+ num_outputs=3,
250
+ parallel=self.parallel,
251
+ batch=False,
252
+ )
253
+
254
+ if self.device.type == "cuda":
255
+ # Move tensors to GPU as external_source won't do that.
256
+ vertices = vertices.gpu()
257
+ attributes = attributes.gpu()
258
+ edges = edges.gpu()
259
+
260
+ # Normalize attributes if statistics are available.
261
+ if self.stats_dir is not None:
262
+ attributes = dali.fn.normalize(attributes, mean=self.mu, stddev=self.sd)
263
+
264
+ # Set outputs.
265
+ pipe.set_outputs(vertices, attributes, edges)
266
+
267
+ return pipe
268
+
269
+ def __iter__(self):
270
+ # Reset the pipeline before creating an iterator to enable epochs.
271
+ self.pipe.reset()
272
+ # Create DALI PyTorch iterator.
273
+ return dali_pth.DALIGenericIterator([self.pipe], ["vertices", "x", "edges"])
274
+
275
+ def __len__(self):
276
+ return self.length
277
+
278
+
279
class MeshDaliExternalSource:
    """DALI Source for lazy-loading with caching of mesh data

    Parameters
    ----------
    data_paths : Iterable[str]
        Paths of the mesh files to serve
    file_format : str
        Mesh format ("vtp", "vtu", or "cgns"); selects reader and parser
    variables : List[str]
        Point-data array names extracted as node attributes
    num_samples : int
        Total number of training samples
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    cache_data : bool, optional
        Whether to cache the data in memory for faster access in subsequent epochs, by default False

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        file_format: str,
        variables: List[str],
        num_samples: int,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
        cache_data: bool = False,
    ):
        self.data_paths = list(data_paths)
        self.file_format = file_format
        self.variables = variables
        # Will be populated later once each worker starts running in its own process.
        self.poly_data = None
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.cache_data = cache_data

        # Tracks the last epoch we shuffled for, so we reshuffle once per epoch.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

        # Resolve the reader/parser once up front (both depend only on file_format).
        self.mesh_reader_fn = self.mesh_reader()
        self.parse_vtk_data_fn = self.parse_vtk_data()

        if self.cache_data:
            # Make cache for the data
            self.data_cache = {}
            for data_path in self.data_paths:
                self.data_cache[data_path] = None

    def __call__(
        self, sample_info: dali.types.SampleInfo
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Return ``(vertices, attributes, edges)`` for one sample.

        Called by DALI once per sample; raises StopIteration to end the epoch.
        """
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        # Shuffle before the next epoch starts.
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers.
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index.
        idx = self.indices[sample_info.idx_in_epoch]

        # if self.poly_data is None: # TODO check
        # This will be called once per worker. Workers are persistent,
        # so there is no need to explicitly close the files - this will be done
        # when corresponding pipeline/dataset is destroyed.
        if self.cache_data:
            # Lazily fill the per-path cache on first access.
            processed_data = self.data_cache.get(self.data_paths[idx])
            if processed_data is None:
                data = self.mesh_reader_fn(self.data_paths[idx])
                processed_data = self.parse_vtk_data_fn(data, self.variables)
                self.data_cache[self.data_paths[idx]] = processed_data
        else:
            data = self.mesh_reader_fn(self.data_paths[idx])
            processed_data = self.parse_vtk_data_fn(data, self.variables)

        return processed_data

    def __len__(self):
        # Number of samples in this rank's shard.
        return len(self.indices)

    def mesh_reader(self):
        """Return the file-reading function matching ``self.file_format``."""
        if self.file_format == "vtp":
            return read_vtp
        if self.file_format == "vtu":
            return read_vtu
        if self.file_format == "cgns":
            return read_cgns
        else:
            raise NotImplementedError(
                f"Data type {self.file_format} is not supported yet"
            )

    def parse_vtk_data(self):
        """Return the VTK-object parsing function matching ``self.file_format``."""
        if self.file_format == "vtp":
            return _parse_vtk_polydata
        elif self.file_format in ["vtu", "cgns"]:
            return _parse_vtk_unstructuredgrid
        else:
            raise NotImplementedError(
                f"Data type {self.file_format} is not supported yet"
            )
400
+
401
+
402
def _parse_vtk_polydata(polydata, variables):
    """Convert vtkPolyData into ``(vertices, attributes, edges)`` tensors.

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        The polydata to parse.
    variables : List[str]
        Point-data array names to extract and concatenate as node attributes.

    Returns
    -------
    tuple of torch.Tensor
        ``vertices`` (float32, [N, 3]), ``attributes`` (float32, [N, C] — a
        single-zero placeholder when ``variables`` is empty), and ``edges``
        (int64 pairs of vertex ids from all polygon boundaries).

    Raises
    ------
    ValueError
        If points, point data, polygons, or a requested array are missing.
    """
    # Fetch vertices
    points = polydata.GetPoints()
    if points is None:
        raise ValueError("Failed to get points from the polydata.")
    num_points = points.GetNumberOfPoints()
    vertices = torch.tensor(
        np.array([points.GetPoint(i) for i in range(num_points)]),
        dtype=torch.float32,
    )

    # Fetch node attributes  # TODO modularize
    attributes = []
    point_data = polydata.GetPointData()
    if point_data is None:
        # BUGFIX: message previously said "unstructured grid" in this
        # polydata-specific parser.
        raise ValueError("Failed to get point data from the polydata.")
    for array_name in variables:
        # BUGFIX: vtk's GetArray returns None for a missing array instead of
        # raising, so the original try/except ValueError never fired and a
        # missing array crashed later with an opaque AttributeError.
        array = point_data.GetArray(array_name)
        if array is None:
            raise ValueError(f"Failed to get array {array_name} from the polydata.")
        array_data = np.zeros((num_points, array.GetNumberOfComponents()))
        for j in range(num_points):
            array.GetTuple(j, array_data[j])
        attributes.append(torch.tensor(array_data, dtype=torch.float32))
    if variables:
        attributes = torch.cat(attributes, dim=-1)
    else:
        # Placeholder tensor for consistency with _parse_vtk_unstructuredgrid
        # (torch.cat on an empty list would raise).
        attributes = torch.zeros((1,), dtype=torch.float32)
    # TODO torch.cat is usually very inefficient when the number of items is large.
    # If possible, the resulting tensor should be pre-allocated and filled in during the loop.

    # Fetch edges: connect consecutive vertex ids of every polygon (wrapping
    # around), accumulating over ALL cells.
    polys = polydata.GetPolys()
    if polys is None:
        raise ValueError("Failed to get polygons from the polydata.")
    polys.InitTraversal()
    edges = []
    id_list = vtk.vtkIdList()
    for _ in range(polys.GetNumberOfCells()):
        polys.GetNextCell(id_list)
        num_ids = id_list.GetNumberOfIds()
        # BUGFIX: the original reassigned `edges` here each iteration, so only
        # the LAST cell's edges were ever returned; extend instead.
        edges.extend(
            (id_list.GetId(j), id_list.GetId((j + 1) % num_ids))
            for j in range(num_ids)
        )
    edges = torch.tensor(edges, dtype=torch.long)

    return vertices, attributes, edges
450
+
451
+
452
+ def _parse_vtk_unstructuredgrid(grid, variables):
453
+ # Fetch vertices
454
+ points = grid.GetPoints()
455
+ if points is None:
456
+ raise ValueError("Failed to get points from the unstructured grid.")
457
+ vertices = torch.tensor(
458
+ np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]),
459
+ dtype=torch.float32,
460
+ )
461
+
462
+ # Fetch node attributes # TODO modularize
463
+ attributes = []
464
+ point_data = grid.GetPointData()
465
+ if point_data is None:
466
+ raise ValueError("Failed to get point data from the unstructured grid.")
467
+ for array_name in variables:
468
+ try:
469
+ array = point_data.GetArray(array_name)
470
+ except ValueError:
471
+ raise ValueError(
472
+ f"Failed to get array {array_name} from the unstructured grid."
473
+ )
474
+ array_data = np.zeros(
475
+ (points.GetNumberOfPoints(), array.GetNumberOfComponents())
476
+ )
477
+ for j in range(points.GetNumberOfPoints()):
478
+ array.GetTuple(j, array_data[j])
479
+ attributes.append(torch.tensor(array_data, dtype=torch.float32))
480
+ if variables:
481
+ attributes = torch.cat(attributes, dim=-1)
482
+ else:
483
+ attributes = torch.zeros((1,), dtype=torch.float32)
484
+
485
+ # Return a dummy tensor of zeros for edges since they are not directly computable
486
+ return (
487
+ vertices,
488
+ attributes,
489
+ torch.zeros((0, 2), dtype=torch.long),
490
+ ) # Dummy tensor for edges
physics_mcp/source/physicsnemo/datapipes/cae/readers.py ADDED
@@ -0,0 +1,191 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import os
18
+ from typing import Any
19
+
20
+ import torch
21
+ import vtk
22
+
23
+ Tensor = torch.Tensor
24
+
25
+
26
def read_vtp(file_path: str) -> Any:  # TODO add support for older format (VTK)
    """Load a VTP (XML PolyData) file.

    Parameters
    ----------
    file_path : str
        Path to the VTP file.

    Returns
    -------
    vtkPolyData
        The polydata read from the VTP file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.vtp`` extension or cannot be parsed.
    """
    # Validate the path before handing it off to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".vtp"):
        raise ValueError(f"Expected a .vtp file, got {file_path}")

    vtp_reader = vtk.vtkXMLPolyDataReader()
    vtp_reader.SetFileName(file_path)
    vtp_reader.Update()

    # VTK signals failure by producing no output object.
    mesh = vtp_reader.GetOutput()
    if mesh is None:
        raise ValueError(f"Failed to read polydata from {file_path}")
    return mesh
60
+
61
+
62
def read_vtu(file_path: str) -> Any:
    """Load a VTU (XML unstructured-grid) file.

    Parameters
    ----------
    file_path : str
        Path to the VTU file.

    Returns
    -------
    vtkUnstructuredGrid
        The unstructured grid data read from the VTU file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.vtu`` extension or cannot be parsed.
    """
    # Validate the path before handing it off to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".vtu"):
        raise ValueError(f"Expected a .vtu file, got {file_path}")

    vtu_reader = vtk.vtkXMLUnstructuredGridReader()
    vtu_reader.SetFileName(file_path)
    vtu_reader.Update()

    # VTK signals failure by producing no output object.
    ugrid = vtu_reader.GetOutput()
    if ugrid is None:
        raise ValueError(f"Failed to read unstructured grid data from {file_path}")
    return ugrid
96
+
97
+
98
def read_cgns(file_path: str) -> Any:
    """Load a CGNS file and return its unstructured grid.

    Parameters
    ----------
    file_path : str
        Path to the CGNS file.

    Returns
    -------
    vtkUnstructuredGrid
        The unstructured grid data read from the CGNS file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.cgns`` extension, cannot be parsed, or
        contains no unstructured grid.
    """
    # Validate the path before handing it off to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".cgns"):
        raise ValueError(f"Expected a .cgns file, got {file_path}")

    cgns_reader = vtk.vtkCGNSReader()
    cgns_reader.SetFileName(file_path)
    cgns_reader.Update()

    # CGNS readers emit a multi-block dataset rather than a grid directly.
    blocks = cgns_reader.GetOutput()
    if blocks is None:
        raise ValueError(f"Failed to read multi-block data from {file_path}")

    # Pull the single vtkUnstructuredGrid out of the multi-block hierarchy.
    return _extract_unstructured_grid(blocks)
133
+
134
+
135
def read_stl(file_path: str) -> vtk.vtkPolyData:
    """Load an STL file.

    Parameters
    ----------
    file_path : str
        Path to the STL file.

    Returns
    -------
    vtkPolyData
        The polydata read from the STL file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.stl`` extension or cannot be parsed.
    """
    # Validate the path before handing it off to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".stl"):
        raise ValueError(f"Expected a .stl file, got {file_path}")

    stl_reader = vtk.vtkSTLReader()
    stl_reader.SetFileName(file_path)
    stl_reader.Update()

    # VTK signals failure by producing no output object.
    mesh = stl_reader.GetOutput()
    if mesh is None:
        raise ValueError(f"Failed to read polydata from {file_path}")
    return mesh
170
+
171
+
172
def _extract_unstructured_grid(
    multi_block: "vtk.vtkMultiBlockDataSet",
) -> "vtk.vtkUnstructuredGrid":
    """
    Extracts a vtkUnstructuredGrid from a vtkMultiBlockDataSet.

    Only index (0, 0) of the hierarchy is inspected — i.e. the first block of
    the first block, matching the layout vtkCGNSReader produces for
    single-base, single-zone files.

    Parameters
    ----------
    multi_block : vtk.vtkMultiBlockDataSet
        The multi-block dataset containing various data blocks.

    Returns
    -------
    vtk.vtkUnstructuredGrid
        The unstructured grid extracted from the multi-block dataset.

    Raises
    ------
    ValueError
        If no vtkUnstructuredGrid is found at index (0, 0).
    """
    outer = multi_block.GetBlock(0)
    # BUGFIX: GetBlock returns None for an empty/out-of-range index; the
    # original code then failed with an opaque AttributeError instead of the
    # intended ValueError.
    if outer is None:
        raise ValueError("No vtkUnstructuredGrid found in the vtkMultiBlockDataSet.")
    block = outer.GetBlock(0)
    if isinstance(block, vtk.vtkUnstructuredGrid):
        return block
    raise ValueError("No vtkUnstructuredGrid found in the vtkMultiBlockDataSet.")
physics_mcp/source/physicsnemo/datapipes/climate/__init__.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from .climate import ClimateDatapipe, ClimateDataSourceSpec
18
+ from .era5_hdf5 import ERA5HDF5Datapipe
19
+ from .synthetic import SyntheticWeatherDataLoader, SyntheticWeatherDataset
physics_mcp/source/physicsnemo/datapipes/climate/climate.py ADDED
@@ -0,0 +1,813 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import json
19
+ from abc import ABC, abstractmethod
20
+ from datetime import datetime, timedelta
21
+ from itertools import chain
22
+
23
+ import h5py
24
+ import netCDF4 as nc
25
+ import numpy as np
26
+ import pytz
27
+ import torch
28
+
29
+ try:
30
+ import nvidia.dali as dali
31
+ import nvidia.dali.plugin.pytorch as dali_pth
32
+ except ImportError:
33
+ raise ImportError(
34
+ "DALI dataset requires NVIDIA DALI package to be installed. "
35
+ + "The package can be installed at:\n"
36
+ + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
37
+ )
38
+
39
+ from dataclasses import dataclass
40
+ from pathlib import Path
41
+ from typing import Callable, Iterable, List, Mapping, Tuple, Union
42
+
43
+ from scipy.io import netcdf_file
44
+
45
+ from physicsnemo.datapipes.climate.utils.invariant import latlon_grid
46
+ from physicsnemo.datapipes.climate.utils.zenith_angle import cos_zenith_angle
47
+ from physicsnemo.datapipes.datapipe import Datapipe
48
+ from physicsnemo.datapipes.meta import DatapipeMetaData
49
+ from physicsnemo.launch.logging import PythonLogger
50
+
51
+ Tensor = torch.Tensor
52
+
53
+
54
@dataclass
class MetaData(DatapipeMetaData):
    # Datapipe metadata flags consumed by the PhysicsNeMo datapipe framework.
    name: str = "Climate"
    # Optimization: pipeline can be auto-placed on a device and captured in
    # CUDA graphs.
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel: supports sharding across DDP ranks.
    ddp_sharding: bool = True
62
+
63
+
64
class ClimateDataSourceSpec:
    """
    A data source specification for ClimateDatapipe.

    HDF5 files should contain the following variable with the corresponding
    name:
    `fields`: Tensor of shape (num_timesteps, num_channels, height, width),
    containing climate data. The order of the channels should match the order
    of the channels in the statistics files. The statistics files should be
    `.npy` files with the shape (1, num_channels, 1, 1).
    The names of the variables are found in the metadata file found in
    `metadata_path`.

    NetCDF4 files should contain a variable of shape
    (num_timesteps, height, width) for each variable they provide. Only the
    variables listed in `variables` will be loaded.

    Parameters
    ----------
    data_dir : str
        Directory where climate data is stored
    name: Union[str, None], optional
        The name that is used to label datapipe outputs from this source.
        If None, the datapipe uses the number of the source in sequential order.
    file_type: str
        Type of files to read, supported values are "hdf5" (default) and "netcdf4"
    stats_files: Union[Mapping[str, str], None], optional
        Numpy files to data statistics for normalization. Supports either a channels
        format, in which case the dict should contain the keys "mean" and "std", or a
        named-variable format, in which case the dict should contain the key "norm" .
        If None, no normalization will be used, by default None
    metadata_path: Union[Mapping[str, str], None], optional for NetCDF, required for HDF5
        Path to the metadata JSON file for the dataset (usually called data.json).
    channels : Union[List[int], None], optional
        Defines which climate variables to load, if None will use all in HDF5 file, by default None
    variables: Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order specified
        by this parameter. Must be used for NetCDF4 files. Supported for HDF5 files
        in which case it will override `channels`.
    use_cos_zenith: bool, optional
        If True, the cosine zenith angles corresponding to the coordinates of this
        data source will be produced, default False
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any auxiliary
        variables returned from this source.
    num_steps : int, optional
        Number of timesteps to return, by default 1
    stride : int, optional
        Number of steps between input and output variables. For example, if the dataset
        contains data at every 6 hours, a stride 1 = 6 hour delta t and
        stride 2 = 12 hours delta t, by default 1
    """

    def __init__(
        self,
        data_dir: str,
        name: Union[str, None] = None,
        file_type: str = "hdf5",
        stats_files: Union[Mapping[str, str], None] = None,
        metadata_path: Union[str, None] = None,
        channels: Union[List[int], None] = None,
        variables: Union[List[str], None] = None,
        use_cos_zenith: bool = False,
        aux_variables: Union[Mapping[str, Callable], None] = None,
        num_steps: int = 1,
        stride: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_dir = Path(data_dir)
        self.name = name
        self.file_type = file_type
        # Normalize statistics file paths to Path objects up front so later
        # existence checks and np.load calls are uniform.
        self.stats_files = (
            {k: Path(fn) for (k, fn) in stats_files.items()}
            if stats_files is not None
            else None
        )
        self.metadata_path = Path(metadata_path) if metadata_path is not None else None
        self.channels = channels
        self.variables = variables
        self.use_cos_zenith = use_cos_zenith
        self.aux_variables = aux_variables if aux_variables is not None else {}
        self.num_steps = num_steps
        self.stride = stride
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs
        self.logger = PythonLogger()

        # NetCDF4 files are read by named variable, so names are mandatory.
        if file_type == "netcdf4" and not variables:
            raise ValueError("Variables must be specified for a NetCDF4 source.")

        # check root directory exists
        if not self.data_dir.is_dir():
            raise IOError(f"Error, data directory {self.data_dir} does not exist")
        if self.stats_files is None:
            self.logger.warning(
                "Warning, no stats files specified, this will result in no normalisation"
            )

    def dimensions_compatible(self, other) -> bool:
        """
        Basic sanity check to test if two `ClimateDataSourceSpec` are
        compatible.

        Only meaningful after `parse_dataset_files` has been called on both
        specs, since the compared attributes are set there.
        """
        return (
            self.data_shape == other.data_shape
            and self.cropped_data_shape == other.cropped_data_shape
            and self.num_samples_per_year == other.num_samples_per_year
            and self.total_length == other.total_length
            and self.n_years == other.n_years
        )

    def parse_dataset_files(
        self,
        num_samples_per_year: Union[int, None] = None,
        patch_size: Union[int, None] = None,
    ) -> None:
        """Parses the data directory for valid files and determines training samples

        Parameters
        ----------
        num_samples_per_year : int, optional
            Number of samples taken from each year. If None, all will be used, by default None
        patch_size : Union[Tuple[int, int], int, None], optional
            If specified, crops input and output variables so image dimensions are
            divisible by patch_size, by default None

        Raises
        ------
        ValueError
            In channels specified or number of samples per year is not valid
        """
        # get all input data files
        suffix = {"hdf5": "h5", "netcdf4": "nc"}[self.file_type]
        self.data_paths = sorted(self.data_dir.glob(f"*.{suffix}"))
        for data_path in self.data_paths:
            self.logger.info(f"Climate data file found: {data_path}")
        self.n_years = len(self.data_paths)
        self.logger.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        # NOTE(review): assumes at least one matching file exists;
        # data_paths[0] raises IndexError on an empty directory.
        self.logger.info(f"Getting file stats from {self.data_paths[0]}")
        if self.file_type == "hdf5":
            with h5py.File(self.data_paths[0], "r") as f:
                dataset_shape = f["fields"].shape
        else:
            # For NetCDF, synthesize an HDF5-like (time, channels, H, W) shape
            # from the first named variable's (time, H, W) shape.
            with nc.Dataset(self.data_paths[0], "r") as f:
                var_shape = f[self.variables[0]].shape
                dataset_shape = (var_shape[0], len(self.variables)) + var_shape[1:]

        # truncate the dataset to avoid out-of-range sampling
        data_samples_per_year = dataset_shape[0] - (self.num_steps - 1) * self.stride
        self.data_shape = dataset_shape[2:]

        # interpret list of variables into list of channels or vice versa
        if self.file_type == "hdf5":
            # Channel names come from the dataset metadata JSON
            # (required for HDF5 sources).
            with open(self.metadata_path, "r") as f:
                metadata = json.load(f)
                data_vars = metadata["coords"]["channel"]
            if self.variables is not None:
                # `variables` overrides `channels` for HDF5 sources.
                self.channels = [data_vars.index(v) for v in self.variables]
            else:
                if self.channels is None:
                    self.variables = data_vars
                else:
                    self.variables = [data_vars[i] for i in self.channels]

        # If channels not provided, use all of them
        if self.channels is None:
            self.channels = list(range(dataset_shape[1]))

        # If num_samples_per_year use all
        if num_samples_per_year is None:
            num_samples_per_year = data_samples_per_year
        self.num_samples_per_year = num_samples_per_year

        # Adjust image shape if patch_size defined
        if patch_size is not None:
            self.cropped_data_shape = tuple(
                s - s % patch_size[i] for i, s in enumerate(self.data_shape)
            )
        else:
            self.cropped_data_shape = self.data_shape
        self.logger.info(f"Input data shape: {self.cropped_data_shape}")

        # Get total length
        self.total_length = self.n_years * self.num_samples_per_year

        # Sanity checks
        if max(self.channels) >= dataset_shape[1]:
            raise ValueError(
                f"Provided channel has indexes greater than the number \
                of fields {dataset_shape[1]}"
            )

        if self.num_samples_per_year > data_samples_per_year:
            raise ValueError(
                f"num_samples_per_year ({self.num_samples_per_year}) > number of \
                samples available ({data_samples_per_year})!"
            )

        self._load_statistics()

        self.logger.info(f"Number of samples/year: {self.num_samples_per_year}")
        self.logger.info(f"Number of channels available: {dataset_shape[1]}")

    def _load_statistics(self) -> None:
        """Loads climate statistics from pre-computed numpy files

        The statistic files should be of name global_means.npy and global_std.npy with
        a shape of [1, C, 1, 1] located in the stat_dir.

        Raises
        ------
        IOError
            If statistics files are not found
        AssertionError
            If loaded numpy arrays are not of correct size
        """
        # If no stats files we just skip loading the stats
        if self.stats_files is None:
            self.mu = None
            self.sd = None
            return
        # load normalisation values
        if set(self.stats_files) == {"mean", "std"}:  # use mean and std files
            mean_stat_file = self.stats_files["mean"]
            std_stat_file = self.stats_files["std"]

            if not mean_stat_file.exists():
                raise IOError(f"Mean statistics file {mean_stat_file} not found")
            if not std_stat_file.exists():
                raise IOError(f"Std statistics file {std_stat_file} not found")

            # has shape [1, C, 1, 1]; select only the channels in use
            self.mu = np.load(str(mean_stat_file))[:, self.channels]
            # has shape [1, C, 1, 1]
            self.sd = np.load(str(std_stat_file))[:, self.channels]
        elif set(self.stats_files) == {
            "norm",
        }:  # use dict formatted file with named variables
            norm_stat_file = self.stats_files["norm"]
            if not norm_stat_file.exists():
                raise IOError(f"Statistics file {norm_stat_file} not found")

            # Pickled dict: {var_name: {"mean": float, "std": float}, ...}
            norm = np.load(str(norm_stat_file), allow_pickle=True).item()
            mu = np.array([norm[var]["mean"] for var in self.variables])
            self.mu = mu.reshape((1, len(mu), 1, 1))
            sd = np.array([norm[var]["std"] for var in self.variables])
            self.sd = sd.reshape((1, len(sd), 1, 1))
        else:
            raise ValueError(("Invalid statistics file specification"))

        # Shapes must broadcast against (batch, channel, H, W) data.
        if not self.mu.shape == self.sd.shape == (1, len(self.channels), 1, 1):
            raise ValueError("Error, normalisation arrays have wrong shape")
319
+
320
+
321
class ClimateDatapipe(Datapipe):
    """
    A Climate DALI data pipeline. This pipeline loads data from
    HDF5/NetCDF4 files. It can also return additional data such as the
    solar zenith angle for each time step. Additionally, it normalizes
    the data if a statistics file is provided. The pipeline returns a dictionary
    with the following structure, where {name} indicates the name of the data
    source provided:

    - ``state_seq-{name}``: Tensors of shape
      (batch_size, num_steps, num_channels, height, width).
      This sequence is drawn from the data file and normalized if a
      statistics file is provided.
    - ``timestamps-{name}``: Tensors of shape (batch_size, num_steps), containing
      timestamps for each timestep in the sequence.
    - ``{aux_variable}-{name}``: Tensors of shape
      (batch_size, num_steps, aux_channels, height, width),
      containing the auxiliary variables returned by each data source
    - ``cos_zenith-{name}``: Tensors of shape (batch_size, num_steps, 1, height, width),
      containing the cosine of the solar zenith angle if specified.
    - ``{invariant_name}``: Tensors of shape (batch_size, invariant_channels, height, width),
      containing the time-invariant data (depending only on spatial coordinates)
      returned by the datapipe. These can include e.g.
      land-sea mask and geopotential/surface elevation.

    To use this data pipeline, your data directory must be structured as
    follows:
    ```
    data_dir
    ├── 1980.h5
    ├── 1981.h5
    ├── 1982.h5
    ├── ...
    └── 2020.h5
    ```

    The files are assumed have no metadata, such as timestamps.
    Because of this, it's important to specify the `dt` parameter and the
    `start_year` parameter so that the pipeline can compute the correct
    timestamps for each timestep. These timestamps are then used to compute the
    cosine of the solar zenith angle, if specified.

    Parameters
    ----------
    sources: Iterable[ClimateDataSourceSpec]
        A list of data specifications defining the sources for the climate variables
    batch_size : int, optional
        Batch size, by default 1
    dt : float, optional
        Time in hours between each timestep in the dataset, by default 6 hr
    start_year : int, optional
        Start year of dataset, by default 1980
    latlon_bounds : Tuple[Tuple[float, float], Tuple[float, float]], optional
        Bounds of latitude and longitude in the data, in the format
        ((lat_start, lat_end,), (lon_start, lon_end)).
        By default ((90, -90), (0, 360)).
    crop_window: Union[Tuple[Tuple[float, float], Tuple[float, float]], None], optional
        The window to crop the data to, in the format ((i0,i1), (j0,j1)) where the
        first spatial dimension will be cropped to i0:i1 and the second to j0:j1.
        If not given, all data will be used.
    invariants : Mapping[str,Callable], optional
        Specifies the time-invariant data (for example latitude and longitude)
        included in the data samples. Should be a dict where the keys are the
        names of the invariants and the values are the corresponding
        functions. The functions need to accept an argument of the shape
        (2, data_shape[0], data_shape[1]) where the first dimension contains
        latitude and longitude in degrees and the other dimensions corresponding
        to the shape of data in the data files. For example,
        invariants={"trig_latlon": invariants.LatLon()}
        will include the sin/cos of lat/lon in the output.
    num_samples_per_year : int, optional
        Number of samples taken from each year. If None, all will be used, by default None
    shuffle : bool, optional
        Shuffle dataset, by default True
    num_workers : int, optional
        Number of workers, by default 1
    device: Union[str, torch.device], optional
        Device for DALI pipeline to run on, by default cuda
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    """

    def __init__(
        self,
        sources: Iterable[ClimateDataSourceSpec],
        batch_size: int = 1,
        dt: float = 6.0,
        start_year: int = 1980,
        latlon_bounds: Tuple[Tuple[float, float], Tuple[float, float]] = (
            (90, -90),
            (0, 360),
        ),
        crop_window: Union[
            Tuple[Tuple[float, float], Tuple[float, float]], None
        ] = None,
        invariants: Union[Mapping[str, Callable], None] = None,
        num_samples_per_year: Union[int, None] = None,
        shuffle: bool = True,
        num_workers: int = 1,  # TODO: is there a faster good default?
        device: Union[str, torch.device] = "cuda",
        process_rank: int = 0,
        world_size: int = 1,
    ):
        super().__init__(meta=MetaData())
        self.sources = list(sources)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.dt = dt
        self.start_year = start_year
        self.data_latlon_bounds = latlon_bounds
        self.process_rank = process_rank
        self.world_size = world_size
        self.num_samples_per_year = num_samples_per_year
        self.logger = PythonLogger()

        if invariants is None:
            invariants = {}

        # Determine outputs of pipeline. Order here must match the order in
        # which outputs are produced in _create_pipeline.
        self.pipe_outputs = []
        for i, spec in enumerate(self.sources):
            # Unnamed sources are labeled by their sequential index.
            name = spec.name if spec.name is not None else i
            self.pipe_outputs += [f"state_seq-{name}", f"timestamps-{name}"]
            self.pipe_outputs.extend(
                f"{aux_var}-{name}" for aux_var in spec.aux_variables
            )
            if spec.use_cos_zenith:
                self.pipe_outputs.append(f"cos_zenith-{name}")
        self.pipe_outputs.extend(invariants.keys())

        # Set up device, needed for pipeline
        if isinstance(device, str):
            device = torch.device(device)

        # Need a index id if cuda
        if device.type == "cuda" and device.index is None:
            device = torch.device("cuda:0")
        self.device = device

        # Load all data files and statistics
        # NOTE(review): the loops below index/slice `sources` directly, so a
        # sequence (not a one-shot iterator) is expected here — TODO confirm.
        for spec in sources:
            spec.parse_dataset_files(num_samples_per_year=num_samples_per_year)
        # All sources must agree on data dimensions (pairwise check).
        for i, spec_i in enumerate(sources):
            for spec_j in sources[i + 1 :]:
                if not spec_i.dimensions_compatible(spec_j):
                    raise ValueError("Incompatible data sources")

        # Lat/lon grid (2, H, W) in degrees covering the full data extent.
        self.data_latlon = np.stack(
            latlon_grid(bounds=self.data_latlon_bounds, shape=sources[0].data_shape),
            axis=0,
        )
        if crop_window is None:
            crop_window = (
                (0, sources[0].cropped_data_shape[0]),
                (0, sources[0].cropped_data_shape[1]),
            )
        self.crop_window = crop_window
        self.window_latlon = self._crop_to_window(self.data_latlon)
        self.window_latlon_dali = dali.types.Constant(self.window_latlon)

        # load invariants
        self.invariants = {
            var: callback(self.window_latlon) for (var, callback) in invariants.items()
        }

        # Create pipeline
        self.pipe = self._create_pipeline()

    def _source_cls_from_type(self, source_type: str) -> type:
        """Get the external source class based on a string descriptor."""
        return {
            "hdf5": ClimateHDF5DaliExternalSource,
            "netcdf4": ClimateNetCDF4DaliExternalSource,
        }[source_type]

    def _crop_to_window(self, x):
        # Crop the last two (spatial) dimensions to the configured window.
        cw = self.crop_window
        if isinstance(x, dali.pipeline.DataNode):
            # DALI doesn't support ellipsis notation
            return x[:, :, cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]
        else:
            return x[..., cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]

    def _source_outputs(self, spec: ClimateDataSourceSpec) -> List:
        """Create DALI outputs for a given data source specification.

        Parameters
        ----------
        spec: ClimateDataSourceSpec
            The data source specification.
        """
        # HDF5/NetCDF source
        source_cls = self._source_cls_from_type(spec.file_type)
        source = source_cls(
            data_paths=spec.data_paths,
            num_samples=spec.total_length,
            channels=spec.channels,
            latlon=self.data_latlon,
            variables=spec.variables,
            aux_variables=spec.aux_variables,
            stride=spec.stride,
            dt=self.dt,
            start_year=self.start_year,
            num_steps=spec.num_steps,
            num_samples_per_year=spec.num_samples_per_year,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            process_rank=self.process_rank,
            world_size=self.world_size,
        )

        # Update length of dataset
        self.total_length = len(source) // self.batch_size

        # Read current batch
        (state_seq, timestamps, *aux) = dali.fn.external_source(
            source,
            num_outputs=source.num_outputs(),
            parallel=True,
            batch=False,
        )

        # Crop
        state_seq = self._crop_to_window(state_seq)
        aux = (self._crop_to_window(x) for x in aux)

        # Normalize (only when statistics were provided for this source)
        if spec.stats_files is not None:
            state_seq = dali.fn.normalize(state_seq, mean=spec.mu, stddev=spec.sd)

        # Make output list
        outputs = [state_seq, timestamps, *aux]

        # Get cosine zenith angle
        if spec.use_cos_zenith:
            cos_zenith = dali.fn.cast(
                cos_zenith_angle(timestamps, latlon=self.window_latlon_dali),
                dtype=dali.types.FLOAT,
            )
            outputs.append(cos_zenith)

        return outputs

    def _invariant_outputs(self):
        # Yield each invariant as a cropped DALI constant.
        for inv in self.invariants.values():
            if self.crop_window is not None:
                inv = self._crop_to_window(inv)
            yield dali.types.Constant(inv)

    def _create_pipeline(self) -> dali.Pipeline:
        """Create DALI pipeline

        Returns
        -------
        dali.Pipeline
            Climate DALI pipeline
        """
        pipe = dali.Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            prefetch_queue_depth=2,
            py_num_workers=self.num_workers,
            device_id=self.device.index,
            py_start_method="spawn",
        )

        with pipe:
            # Concatenate outputs from all sources as well as invariants
            outputs = list(
                chain(
                    *(self._source_outputs(spec) for spec in self.sources),
                    self._invariant_outputs(),
                )
            )

            if self.device.type == "cuda":
                # Move tensors to GPU as external_source won't do that
                outputs = [o.gpu() for o in outputs]

            # Set outputs
            pipe.set_outputs(*outputs)

        return pipe

    def __iter__(self):
        # Reset the pipeline before creating an iterator to enable epochs.
        self.pipe.reset()
        # Create DALI PyTorch iterator.
        return dali_pth.DALIGenericIterator([self.pipe], self.pipe_outputs)

    def __len__(self):
        # Number of full batches per epoch (set in _source_outputs).
        return self.total_length
616
+
617
+
618
class ClimateDaliExternalSource(ABC):
    """DALI Source for lazy-loading the HDF5/NetCDF4 climate files

    Parameters
    ----------
    data_paths : Iterable[str]
        Directory where climate data is stored
    num_samples : int
        Total number of training samples
    channels : Iterable[int]
        List representing which climate variables to load
    num_steps : int
        Number of timesteps to load
    stride : int
        Number of steps between input and output variables
    dt : float, optional
        Time in hours between each timestep in the dataset, by default 6 hr
    start_year : int, optional
        Start year of dataset, by default 1980
    num_samples_per_year : int
        Number of samples randomly taken from each year
    variables: Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order specified
        by this parameter.
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any auxiliary
        variables returned from this source.
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        num_samples: int,
        channels: Iterable[int],
        num_steps: int,
        stride: int,
        dt: float,
        start_year: int,
        num_samples_per_year: int,
        latlon: np.ndarray,
        variables: Union[List[str], None] = None,
        # NOTE(review): annotated as a list but consumed via .values() in
        # __call__, so callers are expected to pass a mapping — the empty
        # tuple default would fail there. TODO confirm intended type.
        aux_variables: List[Union[str, Callable]] = (),
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_paths = list(data_paths)
        # Will be populated later once each worker starts running in its own process.
        self.data_files = [None] * len(self.data_paths)
        self.num_samples = num_samples
        self.chans = list(channels)
        self.latlon = latlon
        self.variables = variables
        self.aux_variables = aux_variables
        self.num_steps = num_steps
        self.stride = stride
        self.dt = dt
        self.start_year = start_year
        self.num_samples_per_year = num_samples_per_year
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs

        # Tracks the last epoch for which indices were shuffled.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

    @abstractmethod
    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        """Write data from year index `year_idx` and sample index `idx` to output"""
        pass

    def __call__(self, sample_info: dali.types.SampleInfo) -> Tuple[Tensor, np.ndarray]:
        # Signal end of epoch once all full batches have been produced.
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        # Shuffle before the next epoch starts
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index
        # TODO: This is very hacky, but it works for now
        idx = self.indices[sample_info.idx_in_epoch]
        year_idx = idx // self.num_samples_per_year
        in_idx = idx % self.num_samples_per_year

        state_seq = self._load_sequence(year_idx, in_idx)

        # Load sequence of timestamps. Files carry no time metadata, so
        # timestamps are reconstructed from start_year, dt and stride.
        year = self.start_year + year_idx
        start_time = datetime(year, 1, 1, tzinfo=pytz.utc) + timedelta(
            hours=int(in_idx) * self.dt
        )
        timestamps = np.array(
            [
                (start_time + timedelta(hours=i * self.stride * self.dt)).timestamp()
                for i in range(self.num_steps)
            ]
        )

        # outputs from auxiliary sources
        aux_outputs = (
            callback(timestamps, self.latlon)
            for callback in self.aux_variables.values()
        )

        return (state_seq, timestamps, *aux_outputs)

    def num_outputs(self):
        # state_seq + timestamps + one output per auxiliary variable.
        return 2 + len(self.aux_variables)

    def __len__(self):
        # Length of this shard (before batching).
        return len(self.indices)
756
+
757
+
758
class ClimateHDF5DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading HDF5 formatted climate data files."""

    def _get_data_file(self, year_idx: int) -> h5py.File:
        """Return the opened file for year `year_idx`."""
        handle = self.data_files[year_idx]
        if handle is None:
            # Opened lazily, once per worker process. Workers are persistent,
            # so there is no need to close the handle explicitly — it is
            # released when the corresponding pipeline/dataset is destroyed.
            # Lazy opening also avoids touching files this shard never reads.
            handle = h5py.File(self.data_paths[year_idx], "r")
            self.data_files[year_idx] = handle
        return handle

    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        """Read a strided window of `num_steps` samples starting at `idx`."""
        # TODO: the data is returned in a weird (time, channels, width, height) shape
        fields = self._get_data_file(year_idx)["fields"]
        stop = idx + self.num_steps * self.stride
        return fields[idx:stop:self.stride, self.chans]
775
+
776
+
777
class ClimateNetCDF4DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading NetCDF4 formatted climate data files."""

    def _get_data_file(self, year_idx: int) -> netcdf_file:
        """Return the opened file for year `year_idx`.

        Raises
        ------
        ValueError
            If ``backend_kwargs["reader"]`` is neither "scipy" nor "netcdf4".
        """
        if self.data_files[year_idx] is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when corresponding pipeline/dataset is destroyed
            # Lazy opening avoids unnecessary file open ops when sharding.
            # NOTE: The SciPy NetCDF reader can be used if the netCDF4 library
            # causes crashes.
            reader = self.backend_kwargs.get("reader", "netcdf4")
            if reader == "scipy":
                self.data_files[year_idx] = netcdf_file(self.data_paths[year_idx])
            elif reader == "netcdf4":
                self.data_files[year_idx] = nc.Dataset(self.data_paths[year_idx], "r")
                # Packing attributes are applied manually in _load_sequence.
                self.data_files[year_idx].set_auto_maskandscale(False)
            else:
                # Fail fast: previously an unrecognized reader silently left
                # the slot as None, surfacing later as an opaque AttributeError.
                raise ValueError(
                    f"Unsupported NetCDF reader {reader!r}; "
                    "expected 'scipy' or 'netcdf4'."
                )

        return self.data_files[year_idx]

    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        """Assemble a (num_steps, num_variables, ...) float32 array for one sample."""
        data_file = self._get_data_file(year_idx)
        # All variables are assumed to share the spatial shape of the first one.
        shape = data_file.variables[self.variables[0]].shape
        shape = (self.num_steps, len(self.variables)) + shape[1:]
        # TODO: this can be optimized to do the NetCDF scale/offset on GPU
        output = np.empty(shape, dtype=np.float32)
        for i, var in enumerate(self.variables):
            v = data_file.variables[var]
            output[:, i] = v[
                idx : idx + self.num_steps * self.stride : self.stride
            ].copy()  # .copy() avoids hanging references
            # Apply CF-style packing manually, since auto mask-and-scale is
            # disabled when the file is opened with the netCDF4 backend.
            if hasattr(v, "scale_factor"):
                output[:, i] *= v.scale_factor
            if hasattr(v, "add_offset"):
                output[:, i] += v.add_offset
        return output
physics_mcp/source/physicsnemo/datapipes/climate/era5_hdf5.py ADDED
@@ -0,0 +1,622 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import h5py
18
+ import numpy as np
19
+ import torch
20
+
21
+ try:
22
+ import nvidia.dali as dali
23
+ import nvidia.dali.plugin.pytorch as dali_pth
24
+ except ImportError:
25
+ raise ImportError(
26
+ "DALI dataset requires NVIDIA DALI package to be installed. "
27
+ + "The package can be installed at:\n"
28
+ + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
29
+ )
30
+
31
+ from dataclasses import dataclass
32
+ from datetime import datetime, timedelta
33
+ from pathlib import Path
34
+ from typing import Dict, Iterable, List, Tuple, Union
35
+
36
+ import pytz
37
+
38
+ from physicsnemo.datapipes.climate.utils.invariant import latlon_grid
39
+ from physicsnemo.datapipes.climate.utils.zenith_angle import cos_zenith_angle
40
+
41
+ from ..datapipe import Datapipe
42
+ from ..meta import DatapipeMetaData
43
+
44
+ Tensor = torch.Tensor
45
+
46
+
47
@dataclass
class MetaData(DatapipeMetaData):
    """Capability metadata for the ERA5 HDF5 datapipe.

    NOTE(review): flag semantics are defined by the base ``DatapipeMetaData``;
    the values below declare this datapipe's supported features.
    """

    name: str = "ERA5HDF5"
    # Optimization
    auto_device: bool = True  # declares automatic device placement support
    cuda_graphs: bool = True  # declares CUDA-graph compatibility
    # Parallel
    ddp_sharding: bool = True  # declares support for sharding across DDP ranks
55
+
56
+
57
class ERA5HDF5Datapipe(Datapipe):
    """ERA5 DALI data pipeline for HDF5 files

    Parameters
    ----------
    data_dir : str
        Directory where ERA5 data is stored
    stats_dir : Union[str, None], optional
        Directory to data statistic numpy files for normalization, if None, no normalization
        will be used, by default None
    channels : Union[List[int], None], optional
        Defines which ERA5 variables to load, if None will use all in HDF5 file, by default None
    batch_size : int, optional
        Batch size, by default 1
    stride : int, optional
        Number of steps between input and output variables. For example, if the dataset
        contains data at every 6 hours, a stride 1 = 6 hour delta t and
        stride 2 = 12 hours delta t, by default 1
    num_steps : int, optional
        Number of timesteps are included in the output variables, by default 1
    num_history : int, optional
        Number of previous timesteps included in the input variables, by default 0
    latlon_resolution: Tuple[int, int], optional
        The resolution for the latitude-longitude grid (H, W). Needs to be specified
        for cos zenith angle computation, or interpolation. By default None
    interpolation_type: str, optional
        Interpolation type for resizing. Supports ["INTERP_NN", "INTERP_LINEAR", "INTERP_CUBIC",
        "INTERP_LANCZOS3", "INTERP_TRIANGULAR", "INTERP_GAUSSIAN"]. By default None
        (no interpolation is done)
    patch_size : Union[Tuple[int, int], int, None], optional
        If specified, crops input and output variables so image dimensions are
        divisible by patch_size, by default None
    num_samples_per_year : int, optional
        Number of samples randomly taken from each year. If None, all will be used, by default None
    use_cos_zenith: bool, optional
        If True, the cosine zenith angles corresponding to the coordinates will be produced,
        by default False
    cos_zenith_args: Dict, optional
        Dictionary containing the following:

        dt: float, optional
            Time in hours between each timestep in the dataset, by default 6 hr

        start_year: int, optional
            Start year of dataset, by default 1980

        latlon_bounds : Tuple[Tuple[float, float], Tuple[float, float]], optional
            Bounds of latitude and longitude in the data, in the format
            ((lat_start, lat_end,), (lon_start, lon_end)).
            By default ((90, -90), (0, 360)).

        Defaults are only applicable if use_cos_zenith is True. Otherwise, defaults to {}.
    use_time_of_year_index: bool
        If true, also returns the index that can be used to determine the time of the year
        corresponding to each sample. By default False.
    shuffle : bool, optional
        Shuffle dataset, by default True
    num_workers : int, optional
        Number of workers, by default 1
    device: Union[str, torch.device], optional
        Device for DALI pipeline to run on, by default cuda
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    """

    def __init__(
        self,
        data_dir: str,
        stats_dir: Union[str, None] = None,
        channels: Union[List[int], None] = None,
        batch_size: int = 1,
        num_steps: int = 1,
        num_history: int = 0,
        stride: int = 1,
        latlon_resolution: Union[Tuple[int, int], None] = None,
        interpolation_type: Union[str, None] = None,
        patch_size: Union[Tuple[int, int], int, None] = None,
        num_samples_per_year: Union[int, None] = None,
        use_cos_zenith: bool = False,
        cos_zenith_args: Dict = {},  # noqa: B006 - copied below before any mutation
        use_time_of_year_index: bool = False,
        shuffle: bool = True,
        num_workers: int = 1,
        device: Union[str, torch.device] = "cuda",
        process_rank: int = 0,
        world_size: int = 1,
    ):
        super().__init__(meta=MetaData())
        # Copy so that filling in cos zenith defaults below does not mutate the
        # caller's dict (or the shared mutable default argument).
        cos_zenith_args = dict(cos_zenith_args)
        self.batch_size = batch_size
        self.num_workers = num_workers
        self.shuffle = shuffle
        self.data_dir = Path(data_dir)
        self.stats_dir = Path(stats_dir) if stats_dir is not None else None
        self.channels = channels
        self.stride = stride
        self.latlon_resolution = latlon_resolution
        self.interpolation_type = interpolation_type
        self.num_steps = num_steps
        self.num_history = num_history
        self.num_samples_per_year = num_samples_per_year
        self.use_cos_zenith = use_cos_zenith
        self.cos_zenith_args = cos_zenith_args
        self.use_time_of_year_index = use_time_of_year_index
        self.process_rank = process_rank
        self.world_size = world_size

        # cos zenith defaults
        if use_cos_zenith:
            cos_zenith_args["dt"] = cos_zenith_args.get("dt", 6.0)
            cos_zenith_args["start_year"] = cos_zenith_args.get("start_year", 1980)
            cos_zenith_args["latlon_bounds"] = cos_zenith_args.get(
                "latlon_bounds",
                (
                    (90, -90),
                    (0, 360),
                ),
            )
            self.latlon_bounds = cos_zenith_args.get("latlon_bounds")

        if isinstance(patch_size, int):
            patch_size = (patch_size, patch_size)
        self.patch_size = patch_size

        # Set up device, needed for pipeline
        if isinstance(device, str):
            device = torch.device(device)
        # Need a index id if cuda
        if device.type == "cuda" and device.index is None:
            device = torch.device("cuda:0")
        self.device = device

        # check root directory exists
        if not self.data_dir.is_dir():
            raise IOError(f"Error, data directory {self.data_dir} does not exist")
        if self.stats_dir is not None and not self.stats_dir.is_dir():
            raise IOError(f"Error, stats directory {self.stats_dir} does not exist")

        # Check interpolation type
        if self.interpolation_type is not None:
            valid_interpolation = [
                "INTERP_NN",
                "INTERP_LINEAR",
                "INTERP_CUBIC",
                "INTERP_LANCZOS3",
                "INTERP_TRIANGULAR",
                "INTERP_GAUSSIAN",
            ]
            if self.interpolation_type not in valid_interpolation:
                raise ValueError(
                    f"Interpolation type {self.interpolation_type} not supported"
                )
            self.interpolation_type = getattr(dali.types, self.interpolation_type)

        # Layout
        # Avoiding API change for self.num_history == 0.
        # Need to use FCHW layout in the future regardless of the num_history.
        if self.num_history == 0:
            self.layout = ["CHW", "FCHW"]
        else:
            self.layout = ["FCHW", "FCHW"]

        self.output_keys = ["invar", "outvar"]

        # Get latlon for zenith angle
        if self.use_cos_zenith:
            if not self.latlon_resolution:
                raise ValueError("latlon_resolution must be set for cos zenith angle")
            self.data_latlon = np.stack(
                latlon_grid(bounds=self.latlon_bounds, shape=self.latlon_resolution),
                axis=0,
            )
            self.latlon_dali = dali.types.Constant(self.data_latlon)
            self.output_keys += ["cos_zenith"]

        if self.use_time_of_year_index:
            self.output_keys += ["time_of_year_idx"]

        self.parse_dataset_files()
        self.load_statistics()

        self.pipe = self._create_pipeline()

    def parse_dataset_files(self) -> None:
        """Parses the data directory for valid HDF5 files and determines training samples

        Raises
        ------
        ValueError
            In channels specified or number of samples per year is not valid
        """
        # get all input data files
        self.data_paths = sorted(self.data_dir.glob("????.h5"))
        for data_path in self.data_paths:
            self.logger.info(f"ERA5 file found: {data_path}")
        self.n_years = len(self.data_paths)
        self.logger.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        self.logger.info(f"Getting file stats from {self.data_paths[0]}")
        with h5py.File(self.data_paths[0], "r") as f:
            # truncate the dataset to avoid out-of-range sampling and ensure each
            # rank has same number of samples (to avoid deadlocks)
            data_samples_per_year = (
                (
                    f["fields"].shape[0]
                    - (self.num_steps + self.num_history) * self.stride
                )
                // self.world_size
            ) * self.world_size
            if data_samples_per_year < 1:
                raise ValueError(
                    f"Not enough number of samples per year ({data_samples_per_year})"
                )
            self.img_shape = f["fields"].shape[2:]

            # If channels not provided, use all of them
            if self.channels is None:
                self.channels = [i for i in range(f["fields"].shape[1])]

            # If num_samples_per_year use all
            if self.num_samples_per_year is None:
                self.num_samples_per_year = data_samples_per_year

            # Adjust image shape if patch_size defined
            if self.patch_size is not None:
                if self.use_cos_zenith:
                    raise ValueError("Patching is not supported with cos zenith angle")
                self.img_shape = [
                    s - s % self.patch_size[i] for i, s in enumerate(self.img_shape)
                ]
            self.logger.info(f"Input image shape: {self.img_shape}")

            # Get total length
            self.total_length = self.n_years * self.num_samples_per_year
            self.length = self.total_length

            # Sanity checks
            if max(self.channels) >= f["fields"].shape[1]:
                raise ValueError(
                    f"Provided channel has indexes greater than the number \
                of fields {f['fields'].shape[1]}"
                )

            if self.num_samples_per_year > data_samples_per_year:
                raise ValueError(
                    f"num_samples_per_year ({self.num_samples_per_year}) > number of \
                samples available ({data_samples_per_year})!"
                )

            self.logger.info(f"Number of samples/year: {self.num_samples_per_year}")
            self.logger.info(f"Number of channels available: {f['fields'].shape[1]}")

    def load_statistics(self) -> None:
        """Loads ERA5 statistics from pre-computed numpy files

        The statistic files should be of name global_means.npy and global_stds.npy with
        a shape of [1, C, 1, 1] located in the stat_dir.

        Raises
        ------
        IOError
            If mean or std numpy files are not found
        AssertionError
            If loaded numpy arrays are not of correct size
        """
        # If no stats dir we just skip loading the stats
        if self.stats_dir is None:
            self.mu = None
            # Fixed: this previously assigned `self.std`, leaving `self.sd`
            # (the attribute used everywhere else) undefined when stats_dir is None.
            self.sd = None
            return
        # load normalisation values
        mean_stat_file = self.stats_dir / Path("global_means.npy")
        std_stat_file = self.stats_dir / Path("global_stds.npy")

        if not mean_stat_file.exists():
            raise IOError(f"Mean statistics file {mean_stat_file} not found")
        if not std_stat_file.exists():
            raise IOError(f"Std statistics file {std_stat_file} not found")

        # has shape [1, C, 1, 1]
        self.mu = np.load(str(mean_stat_file))[:, self.channels]
        # has shape [1, C, 1, 1]
        self.sd = np.load(str(std_stat_file))[:, self.channels]

        if not self.mu.shape == self.sd.shape == (1, len(self.channels), 1, 1):
            raise AssertionError("Error, normalisation arrays have wrong shape")

    def _create_pipeline(self) -> dali.Pipeline:
        """Create DALI pipeline

        Returns
        -------
        dali.Pipeline
            HDF5 DALI pipeline
        """
        pipe = dali.Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            prefetch_queue_depth=2,
            py_num_workers=self.num_workers,
            device_id=self.device.index,
            py_start_method="spawn",
        )

        with pipe:
            source = ERA5DaliExternalSource(
                data_paths=self.data_paths,
                num_samples=self.total_length,
                channels=self.channels,
                stride=self.stride,
                num_steps=self.num_steps,
                num_history=self.num_history,
                num_samples_per_year=self.num_samples_per_year,
                use_cos_zenith=self.use_cos_zenith,
                cos_zenith_args=self.cos_zenith_args,
                use_time_of_year_index=self.use_time_of_year_index,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                process_rank=self.process_rank,
                world_size=self.world_size,
            )
            # Update length of dataset
            self.length = len(source) // self.batch_size
            # Read current batch.
            invar, outvar, timestamps, time_of_year_idx = dali.fn.external_source(
                source,
                num_outputs=4,
                parallel=True,
                batch=False,
                layout=self.layout,
            )
            if self.device.type == "cuda":
                # Move tensors to GPU as external_source won't do that.
                invar = invar.gpu()
                outvar = outvar.gpu()

            # Crop.
            h, w = self.img_shape
            if self.num_history == 0:
                invar = invar[:, :h, :w]
            else:
                invar = invar[:, :, :h, :w]
            outvar = outvar[:, :, :h, :w]

            # Standardize.
            if self.stats_dir is not None:
                if self.num_history == 0:
                    invar = dali.fn.normalize(invar, mean=self.mu[0], stddev=self.sd[0])
                else:
                    invar = dali.fn.normalize(invar, mean=self.mu, stddev=self.sd)
                outvar = dali.fn.normalize(outvar, mean=self.mu, stddev=self.sd)

            # Resize.
            if self.interpolation_type is not None:
                invar = dali.fn.resize(
                    invar,
                    resize_x=self.latlon_resolution[1],
                    resize_y=self.latlon_resolution[0],
                    interp_type=self.interpolation_type,
                    antialias=False,
                )
                outvar = dali.fn.resize(
                    outvar,
                    resize_x=self.latlon_resolution[1],
                    resize_y=self.latlon_resolution[0],
                    interp_type=self.interpolation_type,
                    antialias=False,
                )

            # cos zenith angle
            if self.use_cos_zenith:
                cos_zenith = dali.fn.cast(
                    cos_zenith_angle(timestamps, latlon=self.latlon_dali),
                    dtype=dali.types.FLOAT,
                )
                if self.device.type == "cuda":
                    cos_zenith = cos_zenith.gpu()

            # Set outputs.
            outputs = (invar, outvar)
            if self.use_cos_zenith:
                outputs += (cos_zenith,)
            if self.use_time_of_year_index:
                outputs += (time_of_year_idx,)
            pipe.set_outputs(*outputs)

        return pipe

    def __iter__(self):
        # Reset the pipeline before creating an iterator to enable epochs.
        self.pipe.reset()
        # Create DALI PyTorch iterator.
        return dali_pth.DALIGenericIterator([self.pipe], self.output_keys)

    def __len__(self):
        return self.length
462
+
463
+
464
class ERA5DaliExternalSource:
    """DALI Source for lazy-loading the HDF5 ERA5 files

    Parameters
    ----------
    data_paths : Iterable[str]
        Directory where ERA5 data is stored
    num_samples : int
        Total number of training samples
    channels : Iterable[int]
        List representing which ERA5 variables to load
    num_steps : int
        Number of timesteps are included in the output variables
    num_history : int
        Number of previous timesteps included in the input variables
    stride : int
        Number of steps between input and output variables
    num_samples_per_year : int
        Number of samples randomly taken from each year
    use_cos_zenith: bool
        If True, the cosine zenith angles corresponding to the coordinates will be produced
    cos_zenith_args: Dict
        Dictionary containing the following:

        dt: float
            Time in hours between each timestep in the dataset

        start_year: int
            Start year of dataset
    use_time_of_year_index: bool
        If True, the per-year sample index of each sample is returned alongside
        the data (otherwise -1 is returned in its place)
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        num_samples: int,
        channels: Iterable[int],
        num_steps: int,
        num_history: int,
        stride: int,
        num_samples_per_year: int,
        use_cos_zenith: bool,
        cos_zenith_args: Dict,
        use_time_of_year_index: bool,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
    ):
        self.data_paths = list(data_paths)
        # Will be populated later once each worker starts running in its own process.
        self.data_files = None
        self.num_samples = num_samples
        self.chans = list(channels)
        self.num_steps = num_steps
        self.num_history = num_history
        self.stride = stride
        self.num_samples_per_year = num_samples_per_year
        self.use_cos_zenith = use_cos_zenith
        self.use_time_of_year_index = use_time_of_year_index
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Tracks the epoch for which `self.indices` was last shuffled.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

        # cos zenith args
        if self.use_cos_zenith:
            self.dt: float = cos_zenith_args.get("dt")
            self.start_year: int = cos_zenith_args.get("start_year")

    def __call__(
        self, sample_info: dali.types.SampleInfo
    ) -> Tuple[np.ndarray, np.ndarray, np.ndarray, np.ndarray]:
        """Produce one sample: (invar, outvar, timestamps, time_of_year_idx).

        `timestamps` is empty unless `use_cos_zenith` is set, and
        `time_of_year_idx` holds -1 unless `use_time_of_year_index` is set.
        """
        # Stop after all full batches are produced (see num_batches above).
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        if self.data_files is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when corresponding pipeline/dataset is destroyed.
            self.data_files = [h5py.File(path, "r") for path in self.data_paths]

        # Shuffle before the next epoch starts.
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers.
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index.
        idx = self.indices[sample_info.idx_in_epoch]
        year_idx = idx // self.num_samples_per_year
        in_idx = idx % self.num_samples_per_year

        # Load sequence of timestamps
        if self.use_cos_zenith:
            # NOTE(review): assumes years are contiguous starting at start_year
            # and each file holds exactly num_samples_per_year usable samples.
            year = self.start_year + year_idx
            start_time = datetime(year, 1, 1, tzinfo=pytz.utc) + timedelta(
                hours=int(in_idx) * self.dt
            )
            # One timestamp per input (history + current) and output frame.
            timestamps = np.array(
                [
                    (
                        start_time + timedelta(hours=i * self.stride * self.dt)
                    ).timestamp()
                    for i in range(self.num_history + self.num_steps + 1)
                ]
            )
        else:
            timestamps = np.array([])
        if self.use_time_of_year_index:
            time_of_year_idx = in_idx
        else:
            time_of_year_idx = -1

        data = self.data_files[year_idx]["fields"]
        if self.num_history == 0:
            # Has [C,H,W] shape.
            invar = data[in_idx, self.chans]
        else:
            # Has [T,C,H,W] shape.
            invar = data[
                in_idx : in_idx + (self.num_history + 1) * self.stride : self.stride,
                self.chans,
            ]

        # Has [T,C,H,W] shape.
        outvar = np.empty((self.num_steps,) + invar.shape[-3:], dtype=invar.dtype)

        for i in range(self.num_steps):
            out_idx = in_idx + (self.num_history + i + 1) * self.stride
            outvar[i] = data[out_idx, self.chans]

        return invar, outvar, timestamps, np.array([time_of_year_idx])

    def __len__(self):
        # Number of samples in this rank's shard (not number of batches).
        return len(self.indices)
physics_mcp/source/physicsnemo/datapipes/climate/era5_netcdf.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
physics_mcp/source/physicsnemo/datapipes/climate/synthetic.py ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+
18
+ import time
19
+ from typing import Any, Dict, List, Tuple
20
+
21
+ import numpy as np
22
+ import torch
23
+ from torch.utils.data import DataLoader, Dataset
24
+
25
+
26
class SyntheticWeatherDataLoader(DataLoader):
    """DataLoader that constructs a :class:`SyntheticWeatherDataset` from its
    arguments.

    Every positional and keyword argument is forwarded to the dataset; the
    DataLoader-specific options (``batch_size``, ``shuffle``, ``num_workers``,
    ``pin_memory``, ``drop_last``) are additionally picked out of ``kwargs``,
    falling back to the usual PyTorch defaults.
    """

    def __init__(self, *args, **kwargs):
        loader_options = {
            "batch_size": kwargs.get("batch_size", 1),
            "shuffle": kwargs.get("shuffle", False),
            "num_workers": kwargs.get("num_workers", 0),
            "pin_memory": kwargs.get("pin_memory", False),
            "drop_last": kwargs.get("drop_last", False),
        }
        super().__init__(
            dataset=SyntheticWeatherDataset(*args, **kwargs), **loader_options
        )
41
+
42
+
43
+ class SyntheticWeatherDataset(Dataset):
44
+ """
45
+ A dataset for generating synthetic temperature data on a latitude-longitude grid for multiple atmospheric layers.
46
+
47
+ Args:
48
+ channels (list): List of channels representing different atmospheric layers.
49
+ num_samples_per_year (int): Total number of days to simulate per year.
50
+ num_steps (int): Number of consecutive days in each training sample.
51
+ grid_size (tuple): Latitude by longitude dimensions of the temperature grid.
52
+ base_temp (float): Base temperature around which variations are simulated.
53
+ amplitude (float): Amplitude of the sinusoidal temperature variation.
54
+ noise_level (float): Standard deviation of the noise added to temperature data.
55
+ **kwargs: Additional keyword arguments for advanced configurations.
56
+ """
57
+
58
+ def __init__(
59
+ self,
60
+ channels: List[int],
61
+ num_samples_per_year: int,
62
+ num_steps: int,
63
+ device: str | torch.device = "cuda",
64
+ grid_size: Tuple[int, int] = (721, 1440),
65
+ base_temp: float = 15,
66
+ amplitude: float = 10,
67
+ noise_level: float = 2,
68
+ **kwargs: Any,
69
+ ):
70
+ self.num_days: int = num_samples_per_year
71
+ self.num_steps: int = num_steps
72
+ self.num_channels: int = len(channels)
73
+ self.device = device
74
+ self.grid_size: Tuple[int, int] = grid_size
75
+ start_time = time.time()
76
+ self.temperatures: np.ndarray = self.generate_data(
77
+ self.num_days,
78
+ self.num_channels,
79
+ self.grid_size,
80
+ base_temp,
81
+ amplitude,
82
+ noise_level,
83
+ )
84
+ print(
85
+ f"Generated synthetic temperature data in {time.time() - start_time:.2f} seconds."
86
+ )
87
+ self.extra_args: Dict[str, Any] = kwargs
88
+
89
+ def generate_data(
90
+ self,
91
+ num_days: int,
92
+ num_channels: int,
93
+ grid_size: Tuple[int, int],
94
+ base_temp: float,
95
+ amplitude: float,
96
+ noise_level: float,
97
+ ) -> np.ndarray:
98
+ """
99
+ Generates synthetic temperature data over a specified number of days for multiple atmospheric layers.
100
+
101
+ Args:
102
+ num_days (int): Number of days to generate data for.
103
+ num_channels (int): Number of channels representing different layers.
104
+ grid_size (tuple): Grid size (latitude, longitude).
105
+ base_temp (float): Base mean temperature for the data.
106
+ amplitude (float): Amplitude of temperature variations.
107
+ noise_level (float): Noise level to add stochasticity to the temperature.
108
+
109
+ Returns:
110
+ numpy.ndarray: A 4D array of temperature values across days, channels, latitudes, and longitudes.
111
+ """
112
+ days = np.arange(num_days)
113
+ latitudes, longitudes = grid_size
114
+
115
+ # Create altitude effect and reshape
116
+ altitude_effect = np.arange(num_channels) * -0.5
117
+ altitude_effect = altitude_effect[
118
+ :, np.newaxis, np.newaxis
119
+ ] # Shape: (num_channels, 1, 1)
120
+ altitude_effect = np.tile(
121
+ altitude_effect, (1, latitudes, longitudes)
122
+ ) # Shape: (num_channels, latitudes, longitudes)
123
+ altitude_effect = altitude_effect[
124
+ np.newaxis, :, :, :
125
+ ] # Shape: (1, num_channels, latitudes, longitudes)
126
+ altitude_effect = np.tile(
127
+ altitude_effect, (num_days, 1, 1, 1)
128
+ ) # Shape: (num_days, num_channels, latitudes, longitudes)
129
+
130
+ # Create latitude variation and reshape
131
+ lat_variation = np.linspace(-amplitude, amplitude, latitudes)
132
+ lat_variation = lat_variation[:, np.newaxis] # Shape: (latitudes, 1)
133
+ lat_variation = np.tile(
134
+ lat_variation, (1, longitudes)
135
+ ) # Shape: (latitudes, longitudes)
136
+ lat_variation = lat_variation[
137
+ np.newaxis, np.newaxis, :, :
138
+ ] # Shape: (1, 1, latitudes, longitudes)
139
+ lat_variation = np.tile(
140
+ lat_variation, (num_days, num_channels, 1, 1)
141
+ ) # Shape: (num_days, num_channels, latitudes, longitudes)
142
+
143
+ # Create time effect and reshape
144
+ time_effect = np.sin(2 * np.pi * days / 365)
145
+ time_effect = time_effect[
146
+ :, np.newaxis, np.newaxis, np.newaxis
147
+ ] # Shape: (num_days, 1, 1, 1)
148
+ time_effect = np.tile(
149
+ time_effect, (1, num_channels, latitudes, longitudes)
150
+ ) # Shape: (num_days, num_channels, latitudes, longitudes)
151
+
152
+ # Generate noise
153
+ noise = np.random.normal(
154
+ scale=noise_level, size=(num_days, num_channels, latitudes, longitudes)
155
+ )
156
+
157
+ # Calculate daily temperatures
158
+ daily_temps = base_temp + altitude_effect + lat_variation + time_effect + noise
159
+
160
+ return daily_temps
161
+
162
+ def __len__(self) -> int:
163
+ """
164
+ Returns the number of samples available in the dataset.
165
+ """
166
+ return self.num_days - self.num_steps
167
+
168
+ def __getitem__(self, idx: int) -> torch.Tensor:
169
+ """
170
+ Retrieves a sample from the dataset at the specified index.
171
+ """
172
+ return [
173
+ {
174
+ "invar": torch.tensor(self.temperatures[idx], dtype=torch.float32).to(
175
+ self.device
176
+ ),
177
+ "outvar": torch.tensor(
178
+ self.temperatures[idx + 1 : idx + self.num_steps + 1],
179
+ dtype=torch.float32,
180
+ ).to(self.device),
181
+ }
182
+ ]
physics_mcp/source/physicsnemo/datapipes/climate/utils/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
physics_mcp/source/physicsnemo/datapipes/climate/utils/invariant.py ADDED
@@ -0,0 +1,139 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ from abc import ABC, abstractmethod
18
+ from typing import List, Tuple
19
+
20
+ import numpy as np
21
+ import xarray as xr
22
+
23
+
24
def latlon_grid(
    bounds: Tuple[Tuple[float, float], Tuple[float, float]] = (
        (90, -90),
        (0, 360),
    ),
    shape: Tuple[int, int] = (1440, 721),
) -> List[np.ndarray]:
    """Infer latitude and longitude coordinates from bounds and data shape on an
    equirectangular grid.

    Parameters
    ----------
    bounds: Tuple[Tuple[float, float], Tuple[float, float]]
        ``((lat_start, lat_end), (lon_start, lon_end))`` in degrees.
    shape: Tuple[int, int]
        ``(num_lat, num_lon)`` grid points.
        NOTE(review): the default ``(1440, 721)`` means 1440 latitude and 721
        longitude points; common 0.25-degree products use ``(721, 1440)`` —
        confirm the intended ordering with callers before relying on the
        default.

    Returns
    -------
    List[np.ndarray]
        ``[lat, lon]`` 2D arrays of shape ``shape`` (``np.meshgrid`` output).
        The original annotation of ``np.ndarray`` was incorrect.
    """

    # get latitudes from data shape
    lat = np.linspace(*bounds[0], shape[0], dtype=np.float32)

    # If the longitude bounds describe a full circle, the last sample would
    # duplicate the first, so sample one extra point and drop the endpoint.
    lon_wraparound = (bounds[1][0] % 360) == (bounds[1][1] % 360)
    if lon_wraparound:
        lon = np.linspace(*bounds[1], shape[1] + 1, dtype=np.float32)[:-1]
    else:
        lon = np.linspace(*bounds[1], shape[1], dtype=np.float32)

    return np.meshgrid(lat, lon, indexing="ij")
46
+
47
+
48
class Invariant(ABC):
    """Abstract base class for data that is invariant to inputs on load.

    Subclasses implement ``__call__`` taking a ``(lat, lon)`` grid pair and
    returning the invariant channel(s) evaluated on that grid.
    """

    @abstractmethod
    def __call__(self, latlon: np.ndarray):
        """Evaluate the invariant on the given latitude/longitude grids."""
        pass


class LatLon(Invariant):
    """Time invariant latitude and longitude coordinates and trig functions"""

    def __init__(
        self, outputs: List[str] = ("sin_lat", "cos_lat", "sin_lon", "cos_lon")
    ):
        """
        Outputs latitude and longitude and their trigonometric functions.

        Parameters
        ----------
        outputs: List[str]
            List of outputs. Supported values are
            `{"lat", "lon", "sin_lat", "cos_lat", "sin_lon", "cos_lon"}`
        """
        self.outputs = outputs

    def __call__(self, latlon: np.ndarray):
        """Stack the requested channels along a new leading axis.

        Parameters
        ----------
        latlon: np.ndarray
            Pair of 2D arrays ``(lat, lon)`` in degrees.

        Returns
        -------
        np.ndarray
            Array of shape ``(len(outputs), *lat.shape)`` with channels in
            the order given by ``self.outputs``.
        """
        (lat, lon) = latlon

        # Renamed from `vars`, which shadowed the builtin of the same name.
        channels = {"lat": lat, "lon": lon}

        # Trigonometric channels are computed only when requested.
        if "sin_lat" in self.outputs:
            channels["sin_lat"] = np.sin(np.deg2rad(lat))
        if "cos_lat" in self.outputs:
            channels["cos_lat"] = np.cos(np.deg2rad(lat))
        if "sin_lon" in self.outputs:
            channels["sin_lon"] = np.sin(np.deg2rad(lon))
        if "cos_lon" in self.outputs:
            channels["cos_lon"] = np.cos(np.deg2rad(lon))

        # Unsupported names raise KeyError here, matching prior behavior.
        return np.stack([channels[o] for o in self.outputs], axis=0)
89
+
90
+
91
class FileInvariant(Invariant):
    """
    Loads a time-invariant variable from a NetCDF4 file. The file should
    contain one or more data variables of dimensions
    `(channels, latitude, longitude)` as well as variables `latitude` and
    `longitude` specifying these coordinates. `latitude` and `longitude`
    can be either 2D or 1D.

    Parameters
    ----------
    filename: str
        Path to the file containing the variable
    var_name: str
        The variable in the file containing the data
    normalize: bool, optional
        If True, normalize the data to zero-mean and unit variance.
        Default False.
    interp_method: str, optional
        Any argument accepted by xarray.DataArray.interp.
        Default 'linear'.
    """

    def __init__(
        self,
        filename: str,
        var_name: str,
        normalize: bool = False,
        interp_method: str = "linear",
    ):
        with xr.open_dataset(filename) as ds:
            # `.load()` forces the (possibly lazily-backed) variable into
            # memory so it remains usable after the dataset file is closed
            # on exiting this `with` block; previously a lazy backend could
            # fail on first access after close.
            self.data = ds[var_name].astype(np.float32).load()
            self.lat = ds["latitude"].to_numpy().astype(np.float32)
            self.lon = ds["longitude"].to_numpy().astype(np.float32)

        # Expand 1D coordinate vectors to 2D grids for interpolation.
        if self.lat.ndim == 1:
            (self.lat, self.lon) = np.meshgrid(self.lat, self.lon, indexing="ij")

        if normalize:
            self.data = (self.data - self.data.mean()) / self.data.std()

        self.interp_method = interp_method

    def __call__(self, latlon: np.ndarray):
        """Interpolate the stored field onto the target (lat, lon) grids.

        Parameters
        ----------
        latlon: np.ndarray
            Pair of 2D arrays ``(lat, lon)`` in degrees.

        Returns
        -------
        np.ndarray
            The field interpolated onto the target grid using
            ``self.interp_method``.
        """
        (lat, lon) = latlon
        lat = xr.DataArray(lat, dims=["latitude", "longitude"])
        lon = xr.DataArray(lon, dims=["latitude", "longitude"])
        return self.data.interp(
            method=self.interp_method, latitude=lat, longitude=lon
        ).to_numpy()
physics_mcp/source/physicsnemo/datapipes/climate/utils/zenith_angle.py ADDED
@@ -0,0 +1,208 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # ignore_header_test
2
+
3
+ # climt/LICENSE
4
+ # @mcgibbon
5
+ # BSD License
6
+ # Copyright (c) 2016, Rodrigo Caballero
7
+ # All rights reserved.
8
+ # Redistribution and use in source and binary forms, with or without modification,
9
+ # are permitted provided that the following conditions are met:
10
+ # * Redistributions of source code must retain the above copyright notice, this
11
+ # list of conditions and the following disclaimer.
12
+ # * Redistributions in binary form must reproduce the above copyright notice, this
13
+ # list of conditions and the following disclaimer in the documentation and/or
14
+ # other materials provided with the distribution.
15
+ # * Neither the name of the copyright holder nor the names of its
16
+ # contributors may be used to endorse or promote products derived from this
17
+ # software without specific prior written permission.
18
+ # THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
19
+ # ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
20
+ # WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
21
+ # IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
22
+ # INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
23
+ # BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
24
+ # DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
25
+ # OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
26
+ # OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
27
+ # OF THE POSSIBILITY OF SUCH DAMAGE.
28
+
29
+
30
+ import datetime
31
+
32
+ import numpy as np
33
+ import pytz
34
+
35
+ try:
36
+ import nvidia.dali as dali
37
+ except ImportError:
38
+ raise ImportError(
39
+ "DALI dataset requires NVIDIA DALI package to be installed. "
40
+ + "The package can be installed at:\n"
41
+ + "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
42
+ )
43
+
44
# Degrees-to-radians conversion factor.
RAD_PER_DEG = np.pi / 180.0
# POSIX timestamp of the J2000 reference epoch, 2000-01-01 12:00:00 UTC.
DATETIME_2000 = datetime.datetime(2000, 1, 1, 12, 0, 0, tzinfo=pytz.utc).timestamp()
46
+
47
+
48
def _dali_mod(a, b):
    """Floored modulo ``a mod b`` expressed with DALI graph primitives,
    via the identity ``a - b * floor(a / b)``."""
    quotient = dali.math.floor(a / b)
    return a - b * quotient
50
+
51
+
52
def cos_zenith_angle(
    time: dali.types.DALIDataType,
    latlon: dali.types.DALIDataType,
):
    """
    Dali datapipe for computing Cosine of sun-zenith angle for lon, lat at time (UTC).

    Parameters
    ----------
    time : dali.types.DALIDataType
        Time in seconds since 2000-01-01 12:00:00 UTC. Shape `(seq_length,)`.
    latlon : dali.types.DALIDataType
        Latitude and longitude in degrees. Shape `(2, nr_lat, nr_lon)`.

    Returns
    -------
    dali.types.DALIDataType
        Cosine of sun-zenith angle. Shape `(seq_length, 1, nr_lat, nr_lon)`.
    """
    # Split the stacked input into separate lat/lon planes, convert degrees
    # to radians, and prepend a time axis: shape (1, 1, nr_lat, nr_lon).
    lat = latlon[dali.newaxis, 0:1, :, :] * RAD_PER_DEG
    lon = latlon[dali.newaxis, 1:2, :, :] * RAD_PER_DEG
    # Reshape time to (seq_length, 1, 1, 1) so it broadcasts against lat/lon.
    time = time[:, dali.newaxis, dali.newaxis, dali.newaxis]
    return _star_cos_zenith(time, lat, lon)
75
+
76
+
77
def _days_from_2000(model_time):  # pragma: no cover
    """Convert a time in seconds since the J2000 epoch to days elapsed."""
    seconds_per_day = 24.0 * 3600.0
    return (model_time - DATETIME_2000) / seconds_per_day
80
+
81
+
82
def _greenwich_mean_sidereal_time(model_time):
    """
    Greenwich mean sidereal time, in radians.

    Parameters
    ----------
    model_time
        Time in seconds since 2000-01-01 12:00:00 UTC (J2000 epoch).

    Returns
    -------
    GMST wrapped to [0, 2*pi), in radians.

    Reference:
        The AIAA 2006 implementation:
        http://www.celestrak.com/publications/AIAA/2006-6753/
    """
    jul_centuries = _days_from_2000(model_time) / 36525.0
    # GMST polynomial in seconds of time (AIAA 2006-6753). The cubic
    # coefficient is 6.2e-6; the previous code wrote it as `6.2 * 10e-6`
    # (= 6.2e-5), ten times the reference value. The numerical impact is
    # tiny for |T| < 1 century, but the constant now matches the reference.
    theta = 67310.54841 + jul_centuries * (
        876600 * 3600
        + 8640184.812866
        + jul_centuries * (0.093104 - jul_centuries * 6.2e-6)
    )

    # 240 seconds of time = 1 degree; convert to radians and wrap to [0, 2*pi).
    theta_radians = _dali_mod((theta / 240.0) * RAD_PER_DEG, 2 * np.pi)
    return theta_radians
98
+
99
+
100
def _local_mean_sidereal_time(model_time, longitude):
    """
    Local mean sidereal time. requires longitude in radians.
    Ref:
        http://www.setileague.org/askdr/lmst.htm
    """
    # LMST is GMST offset by the observer's longitude (radians).
    gmst = _greenwich_mean_sidereal_time(model_time)
    return gmst + longitude
107
+
108
+
109
def _sun_ecliptic_longitude(model_time):
    """
    Ecliptic longitude of the sun, in radians.
    Reference:
        http://www.geoastro.de/elevaz/basics/meeus.htm
    """
    julian_centuries = _days_from_2000(model_time) / 36525.0

    # mean anomaly calculation (cubic polynomial in Julian centuries, degrees
    # converted to radians)
    mean_anomaly = (
        357.52910
        + 35999.05030 * julian_centuries
        - 0.0001559 * julian_centuries * julian_centuries
        - 0.00000048 * julian_centuries * julian_centuries * julian_centuries
    ) * RAD_PER_DEG

    # mean longitude
    mean_longitude = (
        280.46645 + 36000.76983 * julian_centuries + 0.0003032 * (julian_centuries**2)
    ) * RAD_PER_DEG

    # equation-of-center correction: difference between true and mean
    # longitude, built from sine harmonics of the mean anomaly
    d_l = (
        (1.914600 - 0.004817 * julian_centuries - 0.000014 * (julian_centuries**2))
        * dali.math.sin(mean_anomaly)
        + (0.019993 - 0.000101 * julian_centuries) * dali.math.sin(2 * mean_anomaly)
        + 0.000290 * dali.math.sin(3 * mean_anomaly)
    ) * RAD_PER_DEG

    # true longitude
    return mean_longitude + d_l
139
+
140
+
141
def _obliquity_star(julian_centuries):
    """
    return obliquity of the sun, in radians
    Use 5th order equation from
    https://en.wikipedia.org/wiki/Ecliptic#Obliquity_of_the_ecliptic
    """
    # Base value 23 deg 26' 21.406" plus a 5th-order polynomial correction in
    # Julian centuries; the correction terms are in arcseconds, hence the
    # division by 3600 to degrees before conversion to radians.
    return (
        23.0
        + 26.0 / 60
        + 21.406 / 3600.0
        - (
            46.836769 * julian_centuries
            - 0.0001831 * (julian_centuries**2)
            + 0.00200340 * (julian_centuries**3)
            - 0.576e-6 * (julian_centuries**4)
            - 4.34e-8 * (julian_centuries**5)
        )
        / 3600.0
    ) * RAD_PER_DEG
160
+
161
+
162
def _right_ascension_declination(model_time):
    """
    Right ascension and declination of the sun, in radians.
    """
    julian_centuries = _days_from_2000(model_time) / 36525.0
    eps = _obliquity_star(julian_centuries)

    eclon = _sun_ecliptic_longitude(model_time)
    # Unit vector toward the sun in ecliptic coordinates, rotated by the
    # obliquity `eps` into equatorial coordinates (sun's ecliptic latitude
    # is approximated as zero).
    x = dali.math.cos(eclon)
    y = dali.math.cos(eps) * dali.math.sin(eclon)
    z = dali.math.sin(eps) * dali.math.sin(eclon)
    # projection of the sun direction onto the equatorial plane
    r = dali.math.sqrt(1.0 - z * z)
    # sun declination
    declination = dali.math.atan2(z, r)
    # right ascension (half-angle form of atan2(y, x))
    right_ascension = 2 * dali.math.atan2(y, (x + r))
    return right_ascension, declination
179
+
180
+
181
def _local_hour_angle(model_time, longitude, right_ascension):
    """
    Hour angle at model_time for the given longitude and right_ascension
    longitude in radians
    Ref:
        https://en.wikipedia.org/wiki/Hour_angle#Relation_with_the_right_ascension
    """
    # H = LMST - RA
    lmst = _local_mean_sidereal_time(model_time, longitude)
    return lmst - right_ascension
189
+
190
+
191
def _star_cos_zenith(model_time, lat, lon):
    """
    Return cosine of star zenith angle
    lon,lat in radians
    Ref:
        Azimuth:
            https://en.wikipedia.org/wiki/Solar_azimuth_angle#Formulas
        Zenith:
            https://en.wikipedia.org/wiki/Solar_zenith_angle
    """
    right_ascension, declination = _right_ascension_declination(model_time)
    hour_angle = _local_hour_angle(model_time, lon, right_ascension)

    # cos(zenith) = sin(lat)*sin(dec) + cos(lat)*cos(dec)*cos(H)
    polar_term = dali.math.sin(lat) * dali.math.sin(declination)
    equatorial_term = (
        dali.math.cos(lat) * dali.math.cos(declination) * dali.math.cos(hour_angle)
    )
    return polar_term + equatorial_term
physics_mcp/source/physicsnemo/datapipes/datapipe.py ADDED
@@ -0,0 +1,60 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.
16
+
17
+ import logging
18
+
19
+ from physicsnemo.datapipes.meta import DatapipeMetaData
20
+
21
+
22
class Datapipe:
    """The base class for all datapipes in PhysicsNeMo.

    Parameters
    ----------
    meta : DatapipeMetaData, optional
        Meta data class for storing info regarding model, by default None
    """

    def __init__(self, meta: DatapipeMetaData = None):
        super().__init__()

        # Fall back to default metadata when none (or the wrong type) is given.
        if not meta or not isinstance(meta, DatapipeMetaData):
            self.meta = DatapipeMetaData()
        else:
            self.meta = meta

        self.logger = logging.getLogger("core.datapipe")
        # `getLogger` returns one shared logger per name, so attach the
        # stream handler only once: previously every Datapipe instantiation
        # added another handler, and each message was emitted multiple times.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.WARNING)

    def debug(self):
        """Turn on debug logging"""
        # Replace any existing handlers with a single debug-format handler.
        self.logger.handlers.clear()
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            f"[%(asctime)s - %(levelname)s - {self.meta.name}] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.DEBUG)
        # TODO: set up debug log
        # fh = logging.FileHandler(f'physicsnemo-core-{self.meta.name}.log')
physics_mcp/source/physicsnemo/datapipes/gnn/__init__.py ADDED
@@ -0,0 +1,15 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
2
+ # SPDX-FileCopyrightText: All rights reserved.
3
+ # SPDX-License-Identifier: Apache-2.0
4
+ #
5
+ # Licensed under the Apache License, Version 2.0 (the "License");
6
+ # you may not use this file except in compliance with the License.
7
+ # You may obtain a copy of the License at
8
+ #
9
+ # http://www.apache.org/licenses/LICENSE-2.0
10
+ #
11
+ # Unless required by applicable law or agreed to in writing, software
12
+ # distributed under the License is distributed on an "AS IS" BASIS,
13
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
14
+ # See the License for the specific language governing permissions and
15
+ # limitations under the License.