Spaces:
Sleeping
Sleeping
update source
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitignore +70 -1
- Dockerfile +3 -0
- physics_mcp/mcp_output/requirements.txt +18 -11
- physics_mcp/source/.dockerignore +8 -0
- physics_mcp/source/.gitattributes +0 -0
- physics_mcp/source/.gitignore +176 -0
- physics_mcp/source/CHANGELOG.md +556 -0
- physics_mcp/source/CITATION.cff +7 -0
- physics_mcp/source/CONTRIBUTING.md +251 -0
- physics_mcp/source/FAQ.md +60 -0
- physics_mcp/source/LICENSE.txt +201 -0
- physics_mcp/source/README.md +472 -0
- physics_mcp/source/SECURITY.md +34 -0
- physics_mcp/source/__init__.py +4 -0
- physics_mcp/source/greptile.json +59 -0
- physics_mcp/source/physicsnemo/__init__.py +22 -0
- physics_mcp/source/physicsnemo/active_learning/README.md +66 -0
- physics_mcp/source/physicsnemo/active_learning/__init__.py +35 -0
- physics_mcp/source/physicsnemo/active_learning/_registry.py +332 -0
- physics_mcp/source/physicsnemo/active_learning/config.py +808 -0
- physics_mcp/source/physicsnemo/active_learning/driver.py +1449 -0
- physics_mcp/source/physicsnemo/active_learning/logger.py +330 -0
- physics_mcp/source/physicsnemo/active_learning/loop.py +534 -0
- physics_mcp/source/physicsnemo/active_learning/protocols.py +1394 -0
- physics_mcp/source/physicsnemo/constants.py +48 -0
- physics_mcp/source/physicsnemo/datapipes/__init__.py +15 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/__init__.py +15 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/darcy.py +322 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kelvin_helmholtz.py +436 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/__init__.py +15 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_difference.py +139 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_volume.py +759 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/indexing.py +182 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/initialization.py +77 -0
- physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/utils.py +141 -0
- physics_mcp/source/physicsnemo/datapipes/cae/__init__.py +18 -0
- physics_mcp/source/physicsnemo/datapipes/cae/cae_dataset.py +1275 -0
- physics_mcp/source/physicsnemo/datapipes/cae/domino_datapipe.py +1334 -0
- physics_mcp/source/physicsnemo/datapipes/cae/mesh_datapipe.py +490 -0
- physics_mcp/source/physicsnemo/datapipes/cae/readers.py +191 -0
- physics_mcp/source/physicsnemo/datapipes/climate/__init__.py +19 -0
- physics_mcp/source/physicsnemo/datapipes/climate/climate.py +813 -0
- physics_mcp/source/physicsnemo/datapipes/climate/era5_hdf5.py +622 -0
- physics_mcp/source/physicsnemo/datapipes/climate/era5_netcdf.py +15 -0
- physics_mcp/source/physicsnemo/datapipes/climate/synthetic.py +182 -0
- physics_mcp/source/physicsnemo/datapipes/climate/utils/__init__.py +15 -0
- physics_mcp/source/physicsnemo/datapipes/climate/utils/invariant.py +139 -0
- physics_mcp/source/physicsnemo/datapipes/climate/utils/zenith_angle.py +208 -0
- physics_mcp/source/physicsnemo/datapipes/datapipe.py +60 -0
- physics_mcp/source/physicsnemo/datapipes/gnn/__init__.py +15 -0
.gitignore
CHANGED
|
@@ -1 +1,70 @@
|
|
| 1 |
-
*.DS_Store
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
*.DS_Store
|
| 2 |
+
|
| 3 |
+
# ===== 源代码中不必要的目录(来自NVIDIA原项目) =====
|
| 4 |
+
# 文档 (102MB)
|
| 5 |
+
physics_mcp/source/docs/
|
| 6 |
+
|
| 7 |
+
# 测试 (28MB)
|
| 8 |
+
physics_mcp/source/test/
|
| 9 |
+
|
| 10 |
+
# 示例 (17MB)
|
| 11 |
+
physics_mcp/source/examples/
|
| 12 |
+
|
| 13 |
+
# ===== Git和CI/CD配置 =====
|
| 14 |
+
physics_mcp/source/.github/
|
| 15 |
+
physics_mcp/source/.gitlab/
|
| 16 |
+
physics_mcp/source/.gitlab-ci.yml
|
| 17 |
+
physics_mcp/source/.pre-commit-config.yaml
|
| 18 |
+
physics_mcp/source/.markdownlint.yaml
|
| 19 |
+
|
| 20 |
+
# ===== 项目配置文件 =====
|
| 21 |
+
physics_mcp/source/Dockerfile
|
| 22 |
+
physics_mcp/source/Makefile
|
| 23 |
+
physics_mcp/source/.gitmodules
|
| 24 |
+
|
| 25 |
+
# ===== Python缓存 =====
|
| 26 |
+
**/__pycache__/
|
| 27 |
+
**/*.py[cod]
|
| 28 |
+
**/*$py.class
|
| 29 |
+
*.so
|
| 30 |
+
.Python
|
| 31 |
+
build/
|
| 32 |
+
develop-eggs/
|
| 33 |
+
dist/
|
| 34 |
+
downloads/
|
| 35 |
+
eggs/
|
| 36 |
+
.eggs/
|
| 37 |
+
lib/
|
| 38 |
+
lib64/
|
| 39 |
+
parts/
|
| 40 |
+
sdist/
|
| 41 |
+
var/
|
| 42 |
+
wheels/
|
| 43 |
+
*.egg-info/
|
| 44 |
+
.installed.cfg
|
| 45 |
+
*.egg
|
| 46 |
+
|
| 47 |
+
# ===== 虚拟环境 =====
|
| 48 |
+
venv/
|
| 49 |
+
env/
|
| 50 |
+
ENV/
|
| 51 |
+
.venv
|
| 52 |
+
|
| 53 |
+
# ===== IDE配置 =====
|
| 54 |
+
.vscode/
|
| 55 |
+
.idea/
|
| 56 |
+
*.swp
|
| 57 |
+
*.swo
|
| 58 |
+
*~
|
| 59 |
+
|
| 60 |
+
# ===== Pytest和覆盖率 =====
|
| 61 |
+
.pytest_cache/
|
| 62 |
+
.coverage
|
| 63 |
+
htmlcov/
|
| 64 |
+
|
| 65 |
+
# ===== 日志和临时文件 =====
|
| 66 |
+
*.log
|
| 67 |
+
*.tmp
|
| 68 |
+
*.tmp.txt
|
| 69 |
+
physics_mcp/mcp_output/mcp_logs/
|
| 70 |
+
physics_mcp/mcp_output/output/
|
Dockerfile
CHANGED
|
@@ -11,6 +11,9 @@ RUN apt-get update && apt-get install -y \
|
|
| 11 |
wget \
|
| 12 |
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
|
|
|
|
|
|
|
|
|
|
| 14 |
# Copy physics_mcp folder
|
| 15 |
COPY physics_mcp /app/physics_mcp
|
| 16 |
|
|
|
|
| 11 |
wget \
|
| 12 |
&& rm -rf /var/lib/apt/lists/*
|
| 13 |
|
| 14 |
+
# Copy source directory (original NVIDIA physicsnemo code) - REQUIRED
|
| 15 |
+
COPY physics_mcp/source /app/physics_mcp/source
|
| 16 |
+
|
| 17 |
# Copy physics_mcp folder
|
| 18 |
COPY physics_mcp /app/physics_mcp
|
| 19 |
|
physics_mcp/mcp_output/requirements.txt
CHANGED
|
@@ -1,19 +1,26 @@
|
|
| 1 |
fastmcp>=0.1.0
|
| 2 |
pydantic>=2.0.0
|
| 3 |
-
torch
|
| 4 |
-
numpy
|
| 5 |
-
scipy
|
| 6 |
-
onnx
|
| 7 |
tritonclient
|
| 8 |
matplotlib
|
| 9 |
pandas
|
| 10 |
-
pyyaml
|
| 11 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
| 12 |
|
| 13 |
-
# Optional Dependencies
|
| 14 |
-
#
|
| 15 |
-
#
|
|
|
|
| 16 |
# dgl
|
| 17 |
# pyg
|
| 18 |
-
# vtk
|
| 19 |
-
# netCDF4
|
|
|
|
|
|
|
|
|
| 1 |
fastmcp>=0.1.0
|
| 2 |
pydantic>=2.0.0
|
| 3 |
+
torch>=2.4.0
|
| 4 |
+
numpy>=1.22.4
|
| 5 |
+
scipy>=1.9.0
|
| 6 |
+
onnx>=1.14.0
|
| 7 |
tritonclient
|
| 8 |
matplotlib
|
| 9 |
pandas
|
| 10 |
+
pyyaml>=6.0
|
| 11 |
+
tqdm>=4.60.0
|
| 12 |
+
xarray>=2023.1.0
|
| 13 |
+
zarr>=2.14.2
|
| 14 |
+
s3fs>=2023.5.0
|
| 15 |
+
timm>=1.0.0
|
| 16 |
|
| 17 |
+
# Optional Dependencies (can be uncommented as needed)
|
| 18 |
+
# cuml>=24.0.0 (requires RAPIDS conda channel - use scipy fallback instead)
|
| 19 |
+
# wandb>=0.13.7
|
| 20 |
+
# mlflow>=2.1.1
|
| 21 |
# dgl
|
| 22 |
# pyg
|
| 23 |
+
# vtk>=9.2.6
|
| 24 |
+
# netCDF4>=1.6.3
|
| 25 |
+
# h5py>=3.7.0
|
| 26 |
+
# nvidia-dali-cuda120>=1.35.0
|
physics_mcp/source/.dockerignore
ADDED
|
@@ -0,0 +1,8 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
.git
|
| 2 |
+
.github
|
| 3 |
+
.gitlab
|
| 4 |
+
.coverage*
|
| 5 |
+
.*cache
|
| 6 |
+
examples
|
| 7 |
+
docs
|
| 8 |
+
test
|
physics_mcp/source/.gitattributes
ADDED
|
File without changes
|
physics_mcp/source/.gitignore
ADDED
|
@@ -0,0 +1,176 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Byte-compiled / optimized / DLL files
|
| 2 |
+
__pycache__/
|
| 3 |
+
*.py[cod]
|
| 4 |
+
*$py.class
|
| 5 |
+
|
| 6 |
+
# C extensions
|
| 7 |
+
*.so
|
| 8 |
+
|
| 9 |
+
# Distribution / packaging
|
| 10 |
+
.Python
|
| 11 |
+
build/
|
| 12 |
+
develop-eggs/
|
| 13 |
+
dist/
|
| 14 |
+
downloads/
|
| 15 |
+
eggs/
|
| 16 |
+
.eggs/
|
| 17 |
+
lib/
|
| 18 |
+
lib64/
|
| 19 |
+
parts/
|
| 20 |
+
sdist/
|
| 21 |
+
var/
|
| 22 |
+
wheels/
|
| 23 |
+
share/python-wheels/
|
| 24 |
+
*.egg-info/
|
| 25 |
+
.installed.cfg
|
| 26 |
+
*.egg
|
| 27 |
+
MANIFEST
|
| 28 |
+
|
| 29 |
+
# PyInstaller
|
| 30 |
+
# Usually these files are written by a python script from a template
|
| 31 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 32 |
+
*.manifest
|
| 33 |
+
*.spec
|
| 34 |
+
|
| 35 |
+
# Installer logs
|
| 36 |
+
pip-log.txt
|
| 37 |
+
pip-delete-this-directory.txt
|
| 38 |
+
|
| 39 |
+
# Unit test / coverage reports
|
| 40 |
+
htmlcov/
|
| 41 |
+
.tox/
|
| 42 |
+
.nox/
|
| 43 |
+
.coverage
|
| 44 |
+
.coverage.*
|
| 45 |
+
.cache
|
| 46 |
+
nosetests.xml
|
| 47 |
+
coverage.xml
|
| 48 |
+
*.cover
|
| 49 |
+
*.py,cover
|
| 50 |
+
.hypothesis/
|
| 51 |
+
.pytest_cache/
|
| 52 |
+
cover/
|
| 53 |
+
|
| 54 |
+
# Translations
|
| 55 |
+
*.mo
|
| 56 |
+
*.pot
|
| 57 |
+
|
| 58 |
+
# Django stuff:
|
| 59 |
+
*.log
|
| 60 |
+
local_settings.py
|
| 61 |
+
db.sqlite3
|
| 62 |
+
db.sqlite3-journal
|
| 63 |
+
|
| 64 |
+
# Flask stuff:
|
| 65 |
+
instance/
|
| 66 |
+
.webassets-cache
|
| 67 |
+
|
| 68 |
+
# Scrapy stuff:
|
| 69 |
+
.scrapy
|
| 70 |
+
|
| 71 |
+
# Sphinx documentation
|
| 72 |
+
docs/_build/
|
| 73 |
+
docs/examples/
|
| 74 |
+
|
| 75 |
+
# PyBuilder
|
| 76 |
+
.pybuilder/
|
| 77 |
+
target/
|
| 78 |
+
|
| 79 |
+
# Jupyter Notebook
|
| 80 |
+
.ipynb_checkpoints
|
| 81 |
+
|
| 82 |
+
# IPython
|
| 83 |
+
profile_default/
|
| 84 |
+
ipython_config.py
|
| 85 |
+
|
| 86 |
+
# pyenv
|
| 87 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 88 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 89 |
+
# .python-version
|
| 90 |
+
|
| 91 |
+
# pipenv
|
| 92 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 93 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 94 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 95 |
+
# install all needed dependencies.
|
| 96 |
+
#Pipfile.lock
|
| 97 |
+
|
| 98 |
+
# poetry
|
| 99 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 100 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 101 |
+
# commonly ignored for libraries.
|
| 102 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 103 |
+
#poetry.lock
|
| 104 |
+
|
| 105 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow
|
| 106 |
+
__pypackages__/
|
| 107 |
+
|
| 108 |
+
# Celery stuff
|
| 109 |
+
celerybeat-schedule
|
| 110 |
+
celerybeat.pid
|
| 111 |
+
|
| 112 |
+
# SageMath parsed files
|
| 113 |
+
*.sage.py
|
| 114 |
+
|
| 115 |
+
# Environments
|
| 116 |
+
.env
|
| 117 |
+
.venv
|
| 118 |
+
env/
|
| 119 |
+
venv/
|
| 120 |
+
ENV/
|
| 121 |
+
env.bak/
|
| 122 |
+
venv.bak/
|
| 123 |
+
|
| 124 |
+
# Spyder project settings
|
| 125 |
+
.spyderproject
|
| 126 |
+
.spyproject
|
| 127 |
+
|
| 128 |
+
# Rope project settings
|
| 129 |
+
.ropeproject
|
| 130 |
+
|
| 131 |
+
# mkdocs documentation
|
| 132 |
+
/site
|
| 133 |
+
|
| 134 |
+
# mypy
|
| 135 |
+
.mypy_cache/
|
| 136 |
+
.dmypy.json
|
| 137 |
+
dmypy.json
|
| 138 |
+
|
| 139 |
+
# Pyre type checker
|
| 140 |
+
.pyre/
|
| 141 |
+
|
| 142 |
+
# pytype static type analyzer
|
| 143 |
+
.pytype/
|
| 144 |
+
|
| 145 |
+
# Cython debug symbols
|
| 146 |
+
cython_debug/
|
| 147 |
+
|
| 148 |
+
# PyCharm
|
| 149 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 150 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 151 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 152 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 153 |
+
.idea/
|
| 154 |
+
|
| 155 |
+
# VsCode
|
| 156 |
+
.vscode/
|
| 157 |
+
.cursor/
|
| 158 |
+
|
| 159 |
+
# VIM
|
| 160 |
+
*.swp
|
| 161 |
+
*~
|
| 162 |
+
|
| 163 |
+
# Additional stuff
|
| 164 |
+
nsight-systems*
|
| 165 |
+
build/
|
| 166 |
+
mlruns/
|
| 167 |
+
checkpoints/
|
| 168 |
+
|
| 169 |
+
# Hydra
|
| 170 |
+
outputs/
|
| 171 |
+
multirun/
|
| 172 |
+
.hydra/
|
| 173 |
+
|
| 174 |
+
# SLURM
|
| 175 |
+
slurm-*.out
|
| 176 |
+
sbatch_logs/
|
physics_mcp/source/CHANGELOG.md
ADDED
|
@@ -0,0 +1,556 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<!-- markdownlint-disable MD024 -->
|
| 2 |
+
# Changelog
|
| 3 |
+
|
| 4 |
+
All notable changes to this project will be documented in this file.
|
| 5 |
+
|
| 6 |
+
The format is based on [Keep a Changelog](https://keepachangelog.com/en/1.0.0/),
|
| 7 |
+
and this project adheres to [Semantic Versioning](https://semver.org/spec/v2.0.0.html).
|
| 8 |
+
|
| 9 |
+
## [1.3.0a0] - 2025-XX-YY
|
| 10 |
+
|
| 11 |
+
### Added
|
| 12 |
+
|
| 13 |
+
- Added mixture_of_experts for weather example in physicsnemo.examples.weather.
|
| 14 |
+
**⚠️Warning:** - It uses experimental DiT model subject to future API changes.
|
| 15 |
+
Added some modifications to DiT architecture in physicsnemo.experimental.models.dit.
|
| 16 |
+
Added learnable option to PositionalEmbedding in physicsnemo.models.diffusion.layers.
|
| 17 |
+
- Added lead-time aware training support to the StormCast example.
|
| 18 |
+
- Add a device aware kNN method to physicsnemo.utils.neighbors. Works with CPU or GPU
|
| 19 |
+
by dispatching to the proper optimized library, and torch.compile compatible.
|
| 20 |
+
- Added additional testing of the DoMINO datapipe.
|
| 21 |
+
- Examples: added a new example for full-waveform inversion using diffusion
|
| 22 |
+
models. Accessible in `examples/geophysics/diffusion_fwi`.
|
| 23 |
+
- Domain Parallelism: Domain Parallelism is now available for kNN, radius_search,
|
| 24 |
+
and torch.nn.functional.pad.
|
| 25 |
+
- Unified recipe for crash modeling, supporting Transolver and MeshGraphNet,
|
| 26 |
+
and three transient schemes.
|
| 27 |
+
- Added a check to `stochastic_sampler` that helps handle the `EDMPrecond` model,
|
| 28 |
+
which has a specific `.forward()` signature
|
| 29 |
+
- Added abstract interfaces for constructing active learning workflows, contained
|
| 30 |
+
under the `physicsnemo.active_learning` namespace. A preliminary example of how
|
| 31 |
+
to compose and define an active learning workflow is provided in `examples/active_learning`.
|
| 32 |
+
The `moons` example provides a minimal (pedagogical) composition that is meant to
|
| 33 |
+
illustrate how to define the necessary parts of the workflow.
|
| 34 |
+
|
| 35 |
+
### Changed
|
| 36 |
+
|
| 37 |
+
- Migrated Stokes MGN example to PyTorch Geometric.
|
| 38 |
+
- Migrated Lennard Jones example to PyTorch Geometric.
|
| 39 |
+
- Migrated physicsnemo.utils.sdf.signed_distance_field to a static return,
|
| 40 |
+
torch-only interface. It also now works on distributed meshes and input fields.
|
| 41 |
+
- Refactored DiTBlock to be more modular
|
| 42 |
+
- Added NATTEN 2D neighborhood attention backend for DiTBlock
|
| 43 |
+
- Migrated blood flow example to PyTorch Geometric.
|
| 44 |
+
- Refactored DoMINO model code and examples for performance optimizations and improved readability.
|
| 45 |
+
- Migrated HydroGraphNet example to PyTorch Geometric.
|
| 46 |
+
- Support for saving and loading nested `physicsnemo.Module`s. It is now
|
| 47 |
+
possible to create nested modules with `m = Module(submodule, ...)`, and save
|
| 48 |
+
and load them with `Module.save` and `Module.from_checkpoint`.
|
| 49 |
+
**⚠️Warning:** - The modules have to be `physicsnemo.Module`s, and not
|
| 50 |
+
`torch.nn.Module`s.
|
| 51 |
+
- Support passing custom tokenizer, detokenizer, and attention `Module`s in
|
| 52 |
+
experimental DiT architecture
|
| 53 |
+
- Improved Transolver training recipe's configuration for checkpointing and normalization.
|
| 54 |
+
- Bumped `multi-storage-client` version to 0.33.0 with rust client.
|
| 55 |
+
|
| 56 |
+
### Deprecated
|
| 57 |
+
|
| 58 |
+
### Removed
|
| 59 |
+
|
| 60 |
+
### Fixed
|
| 61 |
+
|
| 62 |
+
- Set `skip_scale` to Python float in U-Net to ensure compilation works.
|
| 63 |
+
- Ensure stream dependencies are handled correctly in physicsnemo.utils.neighbors
|
| 64 |
+
- Fixed the issue with incorrect handling of files with consecutive runs of
|
| 65 |
+
`combine_stl_solids.py` in the X-MGN recipe.
|
| 66 |
+
- Fixed the `RuntimeError: Worker data receiving interrupted` error in the datacenter example.
|
| 67 |
+
|
| 68 |
+
### Security
|
| 69 |
+
|
| 70 |
+
### Dependencies
|
| 71 |
+
|
| 72 |
+
## [1.2.0] - 2025-08-26
|
| 73 |
+
|
| 74 |
+
### Added
|
| 75 |
+
|
| 76 |
+
- Diffusion Transformer (DiT) model. The DiT model can be accessed in
|
| 77 |
+
`physicsnemo.experimental.models.dit.DiT`. **⚠️Warning:** - Experimental feature
|
| 78 |
+
subject to future API changes.
|
| 79 |
+
- Improved documentation for diffusion models and diffusion utils.
|
| 80 |
+
- Safe API to override `__init__`'s arguments saved in checkpoint file with
|
| 81 |
+
`Module.from_checkpoint("chkpt.mdlus", override_args=set(...))`.
|
| 82 |
+
- PyTorch Geometric MeshGraphNet backend.
|
| 83 |
+
- Functionality in DoMINO to take arbitrary number of `scalar` or `vector`
|
| 84 |
+
global parameters and encode them using `class ParameterModel`
|
| 85 |
+
- TopoDiff model and example.
|
| 86 |
+
- Added ability for DoMINO model to return volume neighbors.
|
| 87 |
+
- Added functionality in DoMINO recipe to introduce physics residual losses.
|
| 88 |
+
- Diffusion models, metrics, and utils: implementation of Student-t
|
| 89 |
+
distribution for EDM-based diffusion models (t-EDM). This feature is adapted
|
| 90 |
+
from the paper [Heavy-Tailed Diffusion Models, Pandey et al.](https://arxiv.org/abs/2410.14171>).
|
| 91 |
+
This includes a new EDM preconditioner (`tEDMPrecondSuperRes`), a loss
|
| 92 |
+
function (`tEDMResidualLoss`), and a new option in corrdiff `diffusion_step`.
|
| 93 |
+
⚠️ This is an experimental feature that can be accessed through the
|
| 94 |
+
`physicsnemo.experimental` module; it might also be subjected to API changes
|
| 95 |
+
without notice.
|
| 96 |
+
- Bumped Ruff version from 0.0.290 to 0.12.5. Replaced Black with `ruff-format`.
|
| 97 |
+
- Domino improvements with Unet attention module and user configs
|
| 98 |
+
- Hybrid MeshGraphNet for modeling structural deformation
|
| 99 |
+
- Enabled TransformerEngine backend in the `transolver` model.
|
| 100 |
+
- Inference code for x-meshgraphnet example for external aerodynamics.
|
| 101 |
+
- Added a new example for external_aerodynamics: training `transolver` on
|
| 102 |
+
irregular mesh data for DrivaerML surface data.
|
| 103 |
+
- Added a new example for external aerodynamics for finetuning pretrained models.
|
| 104 |
+
|
| 105 |
+
### Changed
|
| 106 |
+
|
| 107 |
+
- Diffusion utils: `physicsnemo.utils.generative` renamed into `physicsnemo.utils.diffusion`
|
| 108 |
+
- Diffusion models: in CorrDiff model wrappers (`EDMPrecondSuperResolution` and
|
| 109 |
+
`UNet`), the arguments `profile_mode` and `amp_mode` cannot be overridden by
|
| 110 |
+
`from_checkpoint`. They are now properties that can be dynamically changed
|
| 111 |
+
*after* the model instantiation with, for example, `model.amp_mode = True`
|
| 112 |
+
and `model.profile_mode = False`.
|
| 113 |
+
- Updated healpix data module to use correct `DistributedSampler` target for
|
| 114 |
+
test data loader
|
| 115 |
+
- Existing DGL-based vortex shedding example has been renamed to `vortex_shedding_mgn_dgl`.
|
| 116 |
+
Added new `vortex_shedding_mgn` example that uses PyTorch Geometric instead.
|
| 117 |
+
- HEALPixLayer can now use earth2grid HEALPix padding ops, if desired
|
| 118 |
+
- Migrated Vortex Shedding Reduced Mesh example to PyTorch Geometric.
|
| 119 |
+
- CorrDiff example: fixed bugs when training regression `UNet`.
|
| 120 |
+
- Diffusion models: fixed bugs related to gradient checkpointing on non-square
|
| 121 |
+
images.
|
| 122 |
+
- Diffusion models: created a separate class `Attention` for clarity and
|
| 123 |
+
modularity. Updated `UNetBlock` accordingly to use the `Attention` class
|
| 124 |
+
instead of custom attention logic. This will update the model architecture
|
| 125 |
+
for `SongUNet`-based diffusion models. Changes are not BC-breaking and are
|
| 126 |
+
transparent to the user.
|
| 127 |
+
- ⚠️ **BC-breaking:** refactored the automatic mixed precision
|
| 128 |
+
(AMP) API in layers and models defined in `physicsnemo/models/diffusion/` for
|
| 129 |
+
improved usability. Note: it is now, not only possible, but *required* to
|
| 130 |
+
explicitly set `model.amp_mode = True` in order to use the model in a
|
| 131 |
+
`torch.autocast` clause. This applies to all `SongUNet`-based models.
|
| 132 |
+
- Diffusion models: fixed and improved API to enable fp16 forward pass in
|
| 133 |
+
`UNet` and `EDMPrecondSuperResolution` model wrappers; fp16 forward pass can
|
| 134 |
+
now be toggled/untoggled by setting `model.use_fp16 = True`.
|
| 135 |
+
- Diffusion models: improved API for Apex group norm. `SongUNet`-based models
|
| 136 |
+
will automatically perform conversion of the input tensors to
|
| 137 |
+
`torch.channels_last` memory format when `model.use_apex_gn` is `True`. New
|
| 138 |
+
warnings are raised when attempting to use Apex group norm on CPU.
|
| 139 |
+
- Diffusion utils: systematic compilation of patching operations in `stochastic_sampler`
|
| 140 |
+
for improved performance.
|
| 141 |
+
- CorrDiff example: added option for Student-t EDM (t-EDM) in `train.py` and
|
| 142 |
+
`generate.py`. When training a CorrDiff diffusion model, this feature can be
|
| 143 |
+
enabled with the hydra overrides `++training.hp.distribution=student_t` and
|
| 144 |
+
`++training.hp.nu_student_t=<nu_value>`. For generation, this feature can be
|
| 145 |
+
enabled with similar overrides: `++generation.distribution=student_t` and
|
| 146 |
+
`++generation.nu_student_t=<nu_value>`.
|
| 147 |
+
- CorrDiff example: the parameters `P_mean` and `P_std` (used to compute the
|
| 148 |
+
noise level `sigma`) are now configurable. They can be set with the hydra
|
| 149 |
+
overrides `++training.hp.P_mean=<P_mean_value>` and
|
| 150 |
+
`++training.hp.P_std=<P_std_value>` for training (and similar ones with
|
| 151 |
+
`training.hp` replaced by `generation` for generation).
|
| 152 |
+
- Diffusion utils: patch-based inference and lead time support with
|
| 153 |
+
deterministic sampler.
|
| 154 |
+
- Existing DGL-based XAeroNet example has been renamed to `xaeronet_dgl`.
|
| 155 |
+
Added new `xaeronet` example that uses PyTorch Geometric instead.
|
| 156 |
+
- Updated the deforming plate example to use the Hybrid MeshGraphNet model.
|
| 157 |
+
- ⚠️ **BC-breaking:** Refactored the `transolver` model to improve
|
| 158 |
+
readability and performance, and extend to more use cases.
|
| 159 |
+
- Diffusion models: improved lead time support for `SongUNetPosLtEmbd` and
|
| 160 |
+
`EDMLoss`. Lead-time embeddings can now be used with/without positional
|
| 161 |
+
embeddings.
|
| 162 |
+
- Diffusion models: consolidate `ApexGroupNorm` and `GroupNorm` in
|
| 163 |
+
`models/diffusion/layers.py` with a factory `get_group_norm` that can
|
| 164 |
+
be used to instantiate either one of them. `get_group_norm` is now the
|
| 165 |
+
recommended way to instantiate a GroupNorm layer in `SongUNet`-based and
|
| 166 |
+
other diffusion models.
|
| 167 |
+
- Physicsnemo models: improved checkpoint loading API in
|
| 168 |
+
`Module.from_checkpoint` that now exposes a `strict` parameter to raise error
|
| 169 |
+
on missing/unexpected keys, similar to that used in
|
| 170 |
+
`torch.nn.Module.load_state_dict`.
|
| 171 |
+
- Migrated Hybrid MGN and deforming plate example to PyTorch Geometric.
|
| 172 |
+
|
| 173 |
+
### Fixed
|
| 174 |
+
|
| 175 |
+
- Bug fixes in DoMINO model in sphere sampling and tensor reshaping
|
| 176 |
+
- Bug fixes in DoMINO utils random sampling and test.py
|
| 177 |
+
- Optimized DoMINO config params based on DrivAer ML
|
| 178 |
+
|
| 179 |
+
## [1.1.1] - 2025-06-16
|
| 180 |
+
|
| 181 |
+
### Fixed
|
| 182 |
+
|
| 183 |
+
- Fixed an inadvertent change to the deterministic sampler 2nd order correction
|
| 184 |
+
- Bug Fix in Domino model ball query layer
|
| 185 |
+
- Fixed bug models/unet/unet.py: setting num_conv_layers=1 gives errors
|
| 186 |
+
|
| 187 |
+
## [1.1.0] - 2025-06-05
|
| 188 |
+
|
| 189 |
+
### Added
|
| 190 |
+
|
| 191 |
+
- Added ReGen score-based data assimilation example
|
| 192 |
+
- General purpose patching API for patch-based diffusion
|
| 193 |
+
- New positional embedding selection strategy for CorrDiff SongUNet models
|
| 194 |
+
- Added Multi-Storage Client to allow checkpointing to/from Object Storage
|
| 195 |
+
- Added a new aerodynamics example using DoMINO to compute design sensitivities
|
| 196 |
+
(e.g., drag adjoint) with respect to underlying input geometry.
|
| 197 |
+
|
| 198 |
+
### Changed
|
| 199 |
+
|
| 200 |
+
- Simplified CorrDiff config files, updated default values
|
| 201 |
+
- Refactored CorrDiff losses and samplers to use the patching API
|
| 202 |
+
- Support for non-square images and patches in patch-based diffusion
|
| 203 |
+
- ERA5 download example updated to use current file format convention and
|
| 204 |
+
restricts global statistics computation to the training set
|
| 205 |
+
- Support for training custom StormCast models and various other improvements for StormCast
|
| 206 |
+
- Updated CorrDiff training code to support multiple patch iterations to amortize
|
| 207 |
+
regression cost and usage of `torch.compile`
|
| 208 |
+
- Refactored `physicsnemo/models/diffusion/layers.py` to optimize data type
|
| 209 |
+
casting workflow, avoiding unnecessary casting under autocast mode
|
| 210 |
+
- Refactored Conv2d to enable fusion of conv2d with bias addition
|
| 211 |
+
- Refactored GroupNorm, UNetBlock, SongUNet, SongUNetPosEmbd to support usage of
|
| 212 |
+
Apex GroupNorm, fusion of activation with GroupNorm, and AMP workflow.
|
| 213 |
+
- Updated SongUNetPosEmbd to avoid unnecessary HtoD Memcpy of `pos_embd`
|
| 214 |
+
- Updated `from_checkpoint` to accommodate conversion between Apex optimized ckp
|
| 215 |
+
and non-optimized ckp
|
| 216 |
+
- Refactored CorrDiff NVTX annotation workflow to be configurable
|
| 217 |
+
- Refactored `ResidualLoss` to support patch-accumulating training for
|
| 218 |
+
amortizing regression costs
|
| 219 |
+
- Explicit handling of Warp device for ball query and sdf
|
| 220 |
+
- Merged SongUNetPosLtEmb with SongUNetPosEmb, add support for batch>1
|
| 221 |
+
- Add lead time embedding support for `positional_embedding_selector`. Enable
|
| 222 |
+
arbitrary positioning of probabilistic variables
|
| 223 |
+
- Enable lead time aware regression without CE loss
|
| 224 |
+
- Bumped minimum PyTorch version from 2.0.0 to 2.4.0, to minimize
|
| 225 |
+
support surface for `physicsnemo.distributed` functionality.
|
| 226 |
+
|
| 227 |
+
### Dependencies
|
| 228 |
+
|
| 229 |
+
- Made `nvidia.dali` an optional dependency
|
| 230 |
+
|
| 231 |
+
## [1.0.1] - 2025-03-25
|
| 232 |
+
|
| 233 |
+
### Added
|
| 234 |
+
|
| 235 |
+
- Added version checks to ensure compatibility with older PyTorch for distributed
|
| 236 |
+
utilities and ShardTensor
|
| 237 |
+
|
| 238 |
+
### Fixed
|
| 239 |
+
|
| 240 |
+
- `EntryPoint` error that occurred during physicsnemo checkpoint loading
|
| 241 |
+
|
| 242 |
+
## [1.0.0] - 2025-03-18
|
| 243 |
+
|
| 244 |
+
### Added
|
| 245 |
+
|
| 246 |
+
- DoMINO model architecture, datapipe and training recipe
|
| 247 |
+
- Added matrix decomposition scheme to improve graph partitioning
|
| 248 |
+
- DrivAerML dataset support in FIGConvNet example.
|
| 249 |
+
- Retraining recipe for DoMINO from a pretrained model checkpoint
|
| 250 |
+
- Prototype support for domain parallelism of using ShardTensor (new).
|
| 251 |
+
- Enable DeviceMesh initialization via DistributedManager.
|
| 252 |
+
- Added Datacenter CFD use case.
|
| 253 |
+
- Add leave-in profiling utilities to physicsnemo, to easily enable torch/python/nsight
|
| 254 |
+
profiling in all aspects of the codebase.
|
| 255 |
+
|
| 256 |
+
### Changed
|
| 257 |
+
|
| 258 |
+
- Refactored StormCast training example
|
| 259 |
+
- Enhancements and bug fixes to DoMINO model and training example
|
| 260 |
+
- Enhancement to parameterize DoMINO model with inlet velocity
|
| 261 |
+
- Moved non-dimensionalization out of domino datapipe to datapipe in domino example
|
| 262 |
+
- Updated utils in `physicsnemo.launch.logging` to avoid unnecessary `wandb` and `mlflow`
|
| 263 |
+
imports
|
| 264 |
+
- Moved to experiment-based Hydra config in Lagrangian-MGN example
|
| 265 |
+
- Make data caching optional in `MeshDatapipe`
|
| 266 |
+
- The use of older `importlib_metadata` library is removed
|
| 267 |
+
|
| 268 |
+
### Deprecated
|
| 269 |
+
|
| 270 |
+
- ProcessGroupConfig is tagged for future deprecation in favor of DeviceMesh.
|
| 271 |
+
|
| 272 |
+
### Fixed
|
| 273 |
+
|
| 274 |
+
- Update pytests to skip when the required dependencies are not present
|
| 275 |
+
- Bug in data processing script in domino training example
|
| 276 |
+
- Fixed NCCL_ASYNC_ERROR_HANDLING deprecation warning
|
| 277 |
+
|
| 278 |
+
### Dependencies
|
| 279 |
+
|
| 280 |
+
- Remove the numpy dependency upper bound
|
| 281 |
+
- Moved pytz and nvtx to optional
|
| 282 |
+
- Update the base image for the Dockerfile
|
| 283 |
+
- Introduce Multi-Storage Client (MSC) as an optional dependency.
|
| 284 |
+
- Introduce `wrapt` as an optional dependency, needed when using
|
| 285 |
+
ShardTensor's automatic domain parallelism
|
| 286 |
+
|
| 287 |
+
## [0.9.0] - 2024-12-04
|
| 288 |
+
|
| 289 |
+
### Added
|
| 290 |
+
|
| 291 |
+
- Graph Transformer processor for GraphCast/GenCast.
|
| 292 |
+
- Utility to generate STL from Signed Distance Field.
|
| 293 |
+
- Metrics for CAE and CFD domain such as integrals, drag, and turbulence invariances and
|
| 294 |
+
spectrum.
|
| 295 |
+
- Added gradient clipping to StaticCapture utilities.
|
| 296 |
+
- Bistride Multiscale MeshGraphNet example.
|
| 297 |
+
- FIGConvUNet model and example.
|
| 298 |
+
- The Transolver model.
|
| 299 |
+
- The XAeroNet model.
|
| 300 |
+
- Incorporated CorrDiff-GEFS-HRRR model into CorrDiff, with lead-time aware SongUNet and
|
| 301 |
+
cross entropy loss.
|
| 302 |
+
- Option to offload checkpoints to further reduce memory usage
|
| 303 |
+
- Added StormCast model training and simple inference to examples
|
| 304 |
+
- Multi-scale geometry features for DoMINO model.
|
| 305 |
+
|
| 306 |
+
### Changed
|
| 307 |
+
|
| 308 |
+
- Refactored CorrDiff training recipe for improved usability
|
| 309 |
+
- Fixed timezone calculation in datapipe cosine zenith utility.
|
| 310 |
+
- Refactored EDMPrecondSRV2 preconditioner and fixed the bug related to the metadata
|
| 311 |
+
- Extended the checkpointing utility to store metadata.
|
| 312 |
+
- Corrected missing export of logging function used by transolver model
|
| 313 |
+
|
| 314 |
+
## [0.8.0] - 2024-09-24
|
| 315 |
+
|
| 316 |
+
### Added
|
| 317 |
+
|
| 318 |
+
- Graph Transformer processor for GraphCast/GenCast.
|
| 319 |
+
- Utility to generate STL from Signed Distance Field.
|
| 320 |
+
- Metrics for CAE and CFD domain such as integrals, drag, and turbulence invariances and
|
| 321 |
+
spectrum.
|
| 322 |
+
- Added gradient clipping to StaticCapture utilities.
|
| 323 |
+
- Bistride Multiscale MeshGraphNet example.
|
| 324 |
+
|
| 325 |
+
### Changed
|
| 326 |
+
|
| 327 |
+
- Refactored CorrDiff training recipe for improved usability
|
| 328 |
+
- Fixed timezone calculation in datapipe cosine zenith utility.
|
| 329 |
+
|
| 330 |
+
## [0.7.0] - 2024-07-23
|
| 331 |
+
|
| 332 |
+
### Added
|
| 333 |
+
|
| 334 |
+
- Code logging for CorrDiff via Wandb.
|
| 335 |
+
- Augmentation pipeline for CorrDiff.
|
| 336 |
+
- Regression output as additional conditioning for CorrDiff.
|
| 337 |
+
- Learnable positional embedding for CorrDiff.
|
| 338 |
+
- Support for patch-based CorrDiff training and generation (stochastic sampling only)
|
| 339 |
+
- Enable CorrDiff multi-gpu generation
|
| 340 |
+
- Diffusion model for fluid data super-resolution (CMU contribution).
|
| 341 |
+
- The Virtual Foundry GraphNet.
|
| 342 |
+
- A synthetic dataloader for global weather prediction models, demonstrated on GraphCast.
|
| 343 |
+
- Sorted Empirical CDF CRPS algorithm
|
| 344 |
+
- Support for history, cos zenith, and downscaling/upscaling in the ERA5 HDF5 dataloader.
|
| 345 |
+
- An example showing how to train a "tensor-parallel" version of GraphCast on a
|
| 346 |
+
Shallow-Water-Equation example.
|
| 347 |
+
- 3D UNet
|
| 348 |
+
- AeroGraphNet example of training of MeshGraphNet on Ahmed body and DrivAerNet datasets.
|
| 349 |
+
- Warp SDF routine
|
| 350 |
+
- DLWP HEALPix model
|
| 351 |
+
- Pangu Weather model
|
| 352 |
+
- Fengwu model
|
| 353 |
+
- SwinRNN model
|
| 354 |
+
- Modulated AFNO model
|
| 355 |
+
|
| 356 |
+
### Changed
|
| 357 |
+
|
| 358 |
+
- Raise `PhysicsNeMoUndefinedGroupError` when querying undefined process groups
|
| 359 |
+
- Changed Indexing error in `examples/cfd/swe_nonlinear_pino` for `physicsnemo` loss function
|
| 360 |
+
- Safeguarding against uninitialized usage of `DistributedManager`
|
| 361 |
+
|
| 362 |
+
### Removed
|
| 363 |
+
|
| 364 |
+
- Remove mlflow from deployment image
|
| 365 |
+
|
| 366 |
+
### Fixed
|
| 367 |
+
|
| 368 |
+
- Fixed bug in the partitioning logic for distributing graph structures
|
| 369 |
+
intended for distributed message-passing.
|
| 370 |
+
- Fixed bugs for corrdiff diffusion training of `EDMv1` and `EDMv2`
|
| 371 |
+
- Fixed bug when trying to save DDP model trained through unified recipe
|
| 372 |
+
|
| 373 |
+
### Dependencies
|
| 374 |
+
|
| 375 |
+
- Update DALI to CUDA 12 compatible version.
|
| 376 |
+
- Update minimum python version to 3.10
|
| 377 |
+
|
| 378 |
+
## [0.6.0] - 2024-04-17
|
| 379 |
+
|
| 380 |
+
### Added
|
| 381 |
+
|
| 382 |
+
- The citation file.
|
| 383 |
+
- Link to the CWA dataset.
|
| 384 |
+
- ClimateDatapipe: an improved datapipe for HDF5/NetCDF4 formatted climate data
|
| 385 |
+
- Performance optimizations to CorrDiff.
|
| 386 |
+
- Physics-Informed Nonlinear Shallow Water Equations example.
|
| 387 |
+
- Warp neighbor search routine with a minimal example.
|
| 388 |
+
- Strict option for loading PhysicsNeMo checkpoints.
|
| 389 |
+
- Regression only or diffusion only inference for CorrDiff.
|
| 390 |
+
- Support for organization level model files on NGC file system
|
| 391 |
+
- Physics-Informed Magnetohydrodynamics example.
|
| 392 |
+
|
| 393 |
+
### Changed
|
| 394 |
+
|
| 395 |
+
- Updated Ahmed Body and Vortex Shedding examples to use Hydra config.
|
| 396 |
+
- Added more config options to FCN AFNO example.
|
| 397 |
+
- Moved positional embedding in CorrDiff from the dataloader to network architecture
|
| 398 |
+
|
| 399 |
+
### Deprecated
|
| 400 |
+
|
| 401 |
+
- `physicsnemo.models.diffusion.preconditioning.EDMPrecondSR`. Use `EDMPrecondSRV2` instead.
|
| 402 |
+
|
| 403 |
+
### Removed
|
| 404 |
+
|
| 405 |
+
- Pickle dependency for CorrDiff.
|
| 406 |
+
|
| 407 |
+
### Fixed
|
| 408 |
+
|
| 409 |
+
- Consistent handling of single GPU runs in DistributedManager
|
| 410 |
+
- Output location of objects downloaded with NGC file system
|
| 411 |
+
- Bug in scaling the conditional input in CorrDiff deterministic sampler
|
| 412 |
+
|
| 413 |
+
### Dependencies
|
| 414 |
+
|
| 415 |
+
- Updated DGL build in Dockerfile
|
| 416 |
+
- Updated default base image
|
| 417 |
+
- Moved Onnx from optional to required dependencies
|
| 418 |
+
- Optional Makani dependency required for SFNO model.
|
| 419 |
+
|
| 420 |
+
## [0.5.0] - 2024-01-25
|
| 421 |
+
|
| 422 |
+
### Added
|
| 423 |
+
|
| 424 |
+
- Distributed process group configuration mechanism.
|
| 425 |
+
- DistributedManager utility to instantiate process groups based on a process group config.
|
| 426 |
+
- Helper functions to facilitate distributed training with shared parameters.
|
| 427 |
+
- Brain anomaly detection example.
|
| 428 |
+
- Updated Frechet Inception Distance to use Wasserstein 2-norm with improved stability.
|
| 429 |
+
- Molecular Dynamics example.
|
| 430 |
+
- Improved usage of GraphPartition, added more flexible ways of defining a partitioned graph.
|
| 431 |
+
- Physics-Informed Stokes Flow example.
|
| 432 |
+
- Profiling markers, benchmarking and performance optimizations for CorrDiff inference.
|
| 433 |
+
- Unified weather model training example.
|
| 434 |
+
|
| 435 |
+
### Changed
|
| 436 |
+
|
| 437 |
+
- MLFLow logging such that only proc 0 logs to MLFlow.
|
| 438 |
+
- FNO given separate methods for constructing lift and spectral encoder layers.
|
| 439 |
+
|
| 440 |
+
### Removed
|
| 441 |
+
|
| 442 |
+
- The experimental SFNO
|
| 443 |
+
|
| 444 |
+
### Dependencies
|
| 445 |
+
|
| 446 |
+
- Removed experimental SFNO dependencies
|
| 447 |
+
- Added CorrDiff dependencies (cftime, einops, pyspng, nvtx)
|
| 448 |
+
- Made tqdm a required dependency
|
| 449 |
+
|
| 450 |
+
## [0.4.0] - 2023-11-20
|
| 451 |
+
|
| 452 |
+
### Added
|
| 453 |
+
|
| 454 |
+
- Added Stokes flow dataset
|
| 455 |
+
- An experimental version of SFNO to be used in unified training recipe for
|
| 456 |
+
weather models
|
| 457 |
+
- Added distributed FFT utility.
|
| 458 |
+
- Added ruff as a linting tool.
|
| 459 |
+
- Ported utilities from PhysicsNeMo Launch to main package.
|
| 460 |
+
- EDM diffusion models and recipes for training and sampling.
|
| 461 |
+
- NGC model registry download integration into package/filesystem.
|
| 462 |
+
- Denoising diffusion tutorial.
|
| 463 |
+
|
| 464 |
+
### Changed
|
| 465 |
+
|
| 466 |
+
- The AFNO input argument `img_size` to `inp_shape`
|
| 467 |
+
- Integrated the network architecture layers from PhysicsNeMo-Sym.
|
| 468 |
+
- Updated the SFNO model, and the training and inference recipes.
|
| 469 |
+
|
| 470 |
+
### Fixed
|
| 471 |
+
|
| 472 |
+
- Fixed physicsnemo.Module `from_checkpoint` to work from custom model classes
|
| 473 |
+
|
| 474 |
+
### Dependencies
|
| 475 |
+
|
| 476 |
+
- Updated the base container to PyTorch 23.10.
|
| 477 |
+
- Updated examples to use Pydantic v2.
|
| 478 |
+
|
| 479 |
+
## [0.3.0] - 2023-09-21
|
| 480 |
+
|
| 481 |
+
### Added
|
| 482 |
+
|
| 483 |
+
- Added ability to compute CRPS(..., dim: int = 0).
|
| 484 |
+
- Added EFI for arbitrary climatological CDF.
|
| 485 |
+
- Added Kernel CRPS implementation (kcrps)
|
| 486 |
+
- Added distributed utilities to create process groups and orthogonal process groups.
|
| 487 |
+
- Added distributed AFNO model implementation.
|
| 488 |
+
- Added distributed utilities for communication of buffers of varying size per rank.
|
| 489 |
+
- Added distributed utilities for message passing across multiple GPUs.
|
| 490 |
+
- Added instructions for docker build on ARM architecture.
|
| 491 |
+
- Added batching support and fix the input time step for the DLWP wrapper.
|
| 492 |
+
|
| 493 |
+
### Changed
|
| 494 |
+
|
| 495 |
+
- Updating file system cache location to physicsnemo folder
|
| 496 |
+
|
| 497 |
+
### Fixed
|
| 498 |
+
|
| 499 |
+
- Fixed physicsnemo uninstall in CI docker image
|
| 500 |
+
|
| 501 |
+
### Security
|
| 502 |
+
|
| 503 |
+
- Handle the tar ball extracts in a safer way.
|
| 504 |
+
|
| 505 |
+
### Dependencies
|
| 506 |
+
|
| 507 |
+
- Updated the base container to latest PyTorch 23.07.
|
| 508 |
+
- Update DGL version.
|
| 509 |
+
- Updated require installs for python wheel
|
| 510 |
+
- Added optional dependency list for python wheel
|
| 511 |
+
|
| 512 |
+
## [0.2.1] - 2023-08-08
|
| 513 |
+
|
| 514 |
+
### Fixed
|
| 515 |
+
|
| 516 |
+
- Added a workaround fix for the CUDA graphs error in multi-node runs
|
| 517 |
+
|
| 518 |
+
### Security
|
| 519 |
+
|
| 520 |
+
- Update `certifi` package version
|
| 521 |
+
|
| 522 |
+
## [0.2.0] - 2023-08-07
|
| 523 |
+
|
| 524 |
+
### Added
|
| 525 |
+
|
| 526 |
+
- Added a CHANGELOG.md
|
| 527 |
+
- Added build support for internal DGL
|
| 528 |
+
- 4D Fourier Neural Operator model
|
| 529 |
+
- Ahmed body dataset
|
| 530 |
+
- Unified Climate Datapipe
|
| 531 |
+
|
| 532 |
+
### Changed
|
| 533 |
+
|
| 534 |
+
- DGL install changed from pypi to source
|
| 535 |
+
- Updated SFNO to add support for super resolution, flexible checkpointing, etc.
|
| 536 |
+
|
| 537 |
+
### Fixed
|
| 538 |
+
|
| 539 |
+
- Fixed issue with torch-harmonics version locking
|
| 540 |
+
- Fixed the PhysicsNeMo editable install
|
| 541 |
+
- Fixed AMP bug in static capture
|
| 542 |
+
|
| 543 |
+
### Security
|
| 544 |
+
|
| 545 |
+
- Fixed security issues with subprocess and urllib in `filesystem.py`
|
| 546 |
+
|
| 547 |
+
### Dependencies
|
| 548 |
+
|
| 549 |
+
- Updated the base container to latest PyTorch base container which is based on torch 2.0
|
| 550 |
+
- Container now supports CUDA 12, Python 3.10
|
| 551 |
+
|
| 552 |
+
## [0.1.0] - 2023-05-08
|
| 553 |
+
|
| 554 |
+
### Added
|
| 555 |
+
|
| 556 |
+
- Initial public release.
|
physics_mcp/source/CITATION.cff
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
cff-version: 1.2.0
|
| 2 |
+
message: "If you use this software, please cite it as below."
|
| 3 |
+
title: "NVIDIA PhysicsNeMo: An open-source framework for physics-based deep learning in science and engineering"
|
| 4 |
+
date-released: "2023-02-24"
|
| 5 |
+
authors:
|
| 6 |
+
- name: "PhysicsNeMo Contributors"
|
| 7 |
+
repository-code: "https://github.com/NVIDIA/physicsnemo"
|
physics_mcp/source/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,251 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# PhysicsNeMo Contribution Guide
|
| 2 |
+
|
| 3 |
+
## Introduction
|
| 4 |
+
|
| 5 |
+
Welcome to Project PhysicsNeMo! We're excited you're here and want to contribute.
|
| 6 |
+
This documentation is intended for individuals and institutions interested in
|
| 7 |
+
contributing to PhysicsNeMo. PhysicsNeMo is an open-source project and, as such, its
|
| 8 |
+
success relies on its community of contributors willing to keep improving it.
|
| 9 |
+
Your contribution will be a valued addition to the code base; we simply ask
|
| 10 |
+
that you read this page and understand our contribution process, whether you
|
| 11 |
+
are a seasoned open-source contributor or whether you are a first-time
|
| 12 |
+
contributor.
|
| 13 |
+
|
| 14 |
+
### Communicate with Us
|
| 15 |
+
|
| 16 |
+
We are happy to talk with you about your needs for PhysicsNeMo and your ideas for
|
| 17 |
+
contributing to the project. One way to do this is to create an issue discussing
|
| 18 |
+
your thoughts. It might be that a very similar feature is under development or
|
| 19 |
+
already exists, so an issue is a great starting point. If you are looking for an
|
| 20 |
+
issue to resolve that will help, refer to the
|
| 21 |
+
[issue](https://github.com/NVIDIA/physicsnemo/issues) section.
|
| 22 |
+
If you are considering collaborating with NVIDIA PhysicsNeMo team to enhance PhysicsNeMo,
|
| 23 |
+
fill this [proposal form](https://forms.gle/fYsbZEtgRWJUQ3oQ9) and
|
| 24 |
+
we will get back to you.
|
| 25 |
+
|
| 26 |
+
## Contribute to PhysicsNeMo-Core
|
| 27 |
+
|
| 28 |
+
### Pull Requests
|
| 29 |
+
|
| 30 |
+
Developer workflow for code contributions is as follows:
|
| 31 |
+
|
| 32 |
+
1. Developers must first [fork](https://help.github.com/en/articles/fork-a-repo)
|
| 33 |
+
the [upstream](https://github.com/NVIDIA/physicsnemo) PhysicsNeMo repository.
|
| 34 |
+
|
| 35 |
+
2. Git clone the forked repository and push changes to the personal fork.
|
| 36 |
+
|
| 37 |
+
3. Once the code changes are staged on the fork and ready for review, a
|
| 38 |
+
[Pull Request](https://help.github.com/en/articles/about-pull-requests) (PR)
|
| 39 |
+
can be [requested](https://help.github.com/en/articles/creating-a-pull-request)
|
| 40 |
+
to merge the changes from a branch of the fork into a selected branch of upstream.
|
| 41 |
+
|
| 42 |
+
- Exercise caution when selecting the source and target branches for the PR.
|
| 43 |
+
- Ensure that you update the [`CHANGELOG.md`](CHANGELOG.md) to reflect your contributions.
|
| 44 |
+
- Creation of a PR kicks off CI and a code review process.
|
| 45 |
+
- At least one PhysicsNeMo engineer will be assigned to the review.
|
| 46 |
+
|
| 47 |
+
4. The PR will be accepted and the corresponding issue closed after adequate review and
|
| 48 |
+
testing has been completed. Note that every PR should correspond to an open issue and
|
| 49 |
+
should be linked on Github.
|
| 50 |
+
|
| 51 |
+
### Licensing Information
|
| 52 |
+
|
| 53 |
+
All source code files should start with this paragraph:
|
| 54 |
+
|
| 55 |
+
```bash
|
| 56 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2024 NVIDIA CORPORATION & AFFILIATES.
|
| 57 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 58 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 59 |
+
#
|
| 60 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 61 |
+
# you may not use this file except in compliance with the License.
|
| 62 |
+
# You may obtain a copy of the License at
|
| 63 |
+
#
|
| 64 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 65 |
+
#
|
| 66 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 67 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 68 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 69 |
+
# See the License for the specific language governing permissions and
|
| 70 |
+
# limitations under the License.
|
| 71 |
+
```
|
| 72 |
+
|
| 73 |
+
### Signing Your Work
|
| 74 |
+
|
| 75 |
+
- We require that all contributors "sign-off" on their commits. This certifies that the
|
| 76 |
+
contribution is your original work, or you have rights to submit it under the same
|
| 77 |
+
license, or a compatible license.
|
| 78 |
+
|
| 79 |
+
- Any contribution which contains commits that are not Signed-Off will not be accepted.
|
| 80 |
+
|
| 81 |
+
- To sign off on a commit you simply use the `--signoff` (or `-s`) option when
|
| 82 |
+
committing your changes:
|
| 83 |
+
|
| 84 |
+
```bash
|
| 85 |
+
git commit -s -m "Add cool feature."
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
This will append the following to your commit message:
|
| 89 |
+
|
| 90 |
+
```text
|
| 91 |
+
Signed-off-by: Your Name <your@email.com>
|
| 92 |
+
```
|
| 93 |
+
|
| 94 |
+
- Full text of the DCO:
|
| 95 |
+
|
| 96 |
+
```text
|
| 97 |
+
Developer Certificate of Origin
|
| 98 |
+
Version 1.1
|
| 99 |
+
|
| 100 |
+
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
| 101 |
+
1 Letterman Drive
|
| 102 |
+
Suite D4700
|
| 103 |
+
San Francisco, CA, 94129
|
| 104 |
+
|
| 105 |
+
Everyone is permitted to copy and distribute verbatim copies of this license
|
| 106 |
+
document, but changing it is not allowed.
|
| 107 |
+
```
|
| 108 |
+
|
| 109 |
+
```text
|
| 110 |
+
Developer's Certificate of Origin 1.1
|
| 111 |
+
|
| 112 |
+
By making a contribution to this project, I certify that:
|
| 113 |
+
|
| 114 |
+
(a) The contribution was created in whole or in part by me and I have the right to
|
| 115 |
+
submit it under the open source license indicated in the file; or
|
| 116 |
+
|
| 117 |
+
(b) The contribution is based upon previous work that, to the best of my knowledge,
|
| 118 |
+
is covered under an appropriate open source license and I have the right under that
|
| 119 |
+
license to submit that work with modifications, whether created in whole or in part
|
| 120 |
+
by me, under the same open source license (unless I am permitted to submit under a
|
| 121 |
+
different license), as indicated in the file; or
|
| 122 |
+
|
| 123 |
+
(c) The contribution was provided directly to me by some other person who certified
|
| 124 |
+
(a), (b) or (c) and I have not modified it.
|
| 125 |
+
|
| 126 |
+
(d) I understand and agree that this project and the contribution are public and
|
| 127 |
+
that a record of the contribution (including all personal information I submit with
|
| 128 |
+
it, including my sign-off) is maintained indefinitely and may be redistributed
|
| 129 |
+
consistent with this project or the open source license(s) involved.
|
| 130 |
+
|
| 131 |
+
```
|
| 132 |
+
|
| 133 |
+
### Pre-commit
|
| 134 |
+
|
| 135 |
+
For PhysicsNeMo development, [pre-commit](https://pre-commit.com/) is **required**.
|
| 136 |
+
This will not only help developers pass the CI pipeline, but also accelerate reviews.
|
| 137 |
+
Contributions that have not used pre-commit will *not be reviewed*.
|
| 138 |
+
|
| 139 |
+
`pre-commit` is installed as part of the `dev` optional dependencies defined in `pyproject.toml`.
|
| 140 |
+
To install `pre-commit` in an existing environment, follow the below steps inside the PhysicsNeMo
|
| 141 |
+
repository folder:
|
| 142 |
+
|
| 143 |
+
```bash
|
| 144 |
+
pip install pre-commit
|
| 145 |
+
pre-commit install
|
| 146 |
+
```
|
| 147 |
+
|
| 148 |
+
Once the above commands are executed, the pre-commit hooks will be activated and all
|
| 149 |
+
the commits will be checked for appropriate formatting.
|
| 150 |
+
|
| 151 |
+
### Continuous Integration (CI)
|
| 152 |
+
|
| 153 |
+
To ensure quality of the code, your merge request (MR) will pass through several CI checks.
|
| 154 |
+
It is mandatory for your MRs to pass these pipelines to ensure a successful merge.
|
| 155 |
+
Please keep checking this document for the latest guidelines on pushing code. Currently,
|
| 156 |
+
the pipeline has the following stages:
|
| 157 |
+
|
| 158 |
+
1. `format`
|
| 159 |
+
*Pre-commit will check this for you!* Checks for formatting of your
|
| 160 |
+
Python code, using `ruff format` via [Ruff](https://docs.astral.sh/ruff/).
|
| 161 |
+
If your MR fails this test, run `ruff format <script-name>.py` on
|
| 162 |
+
problematic scripts and Ruff will take care of the rest.
|
| 163 |
+
|
| 164 |
+
2. `interrogate`
|
| 165 |
+
*Pre-commit will check this for you!*
|
| 166 |
+
Checks if the code being pushed is well documented. The goal is to make the
|
| 167 |
+
documentation live inside code. Very few exceptions are made.
|
| 168 |
+
Elements that are fine to have no documentation include `init-module`, `init-method`,
|
| 169 |
+
`private` and `semiprivate` classes/functions and `dunder` methods. For definitions of
|
| 170 |
+
these, refer [interrogate](https://interrogate.readthedocs.io/en/latest/). Meaning for
|
| 171 |
+
some methods/functions is very explicit and exceptions for these are made. These
|
| 172 |
+
include `forward`, `reset_parameters`, `extra_repr`, `MetaData`. If your MR fails this
|
| 173 |
+
test, add the missing documentation. Take a look at the pipeline output for hints on
|
| 174 |
+
which functions/classes need documentation.
|
| 175 |
+
To test the documentation before making a commit, you can run the following during
|
| 176 |
+
your development
|
| 177 |
+
|
| 178 |
+
```bash
|
| 179 |
+
interrogate \
|
| 180 |
+
--ignore-init-method \
|
| 181 |
+
--ignore-init-module \
|
| 182 |
+
--ignore-module \
|
| 183 |
+
--ignore-private \
|
| 184 |
+
--ignore-semiprivate \
|
| 185 |
+
--ignore-magic \
|
| 186 |
+
--fail-under 99 \
|
| 187 |
+
--exclude '[setup.py]' \
|
| 188 |
+
--ignore-regex forward \
|
| 189 |
+
--ignore-regex reset_parameters \
|
| 190 |
+
--ignore-regex extra_repr \
|
| 191 |
+
--ignore-regex MetaData \
|
| 192 |
+
-vv \
|
| 193 |
+
--color \
|
| 194 |
+
./physicsnemo/
|
| 195 |
+
```
|
| 196 |
+
|
| 197 |
+
3. `lint`
|
| 198 |
+
*Pre-commit will check this for you!*
|
| 199 |
+
Linters will perform static analysis to check the style, complexity, errors
|
| 200 |
+
and more. For markdown files `markdownlint` is used, it's suggested to use
|
| 201 |
+
the vscode, neovim or sublime
|
| 202 |
+
[extensions](https://github.com/DavidAnson/markdownlint#related).
|
| 203 |
+
PhysicsNeMo uses `ruff check` via [Ruff](https://docs.astral.sh/ruff/) for
|
| 204 |
+
linting of various types. Currently we use flake8/pycodestyle (`E`),
|
| 205 |
+
Pyflakes (`F`), flake8-bandit (`S`), isort (`I`), and performance (`PERF`)
|
| 206 |
+
rules. Many rule violations will be automatically fixed by Ruff; others may
|
| 207 |
+
require manual changes.
|
| 208 |
+
|
| 209 |
+
4. `license`
|
| 210 |
+
*Pre-commit will check this for you!*
|
| 211 |
+
Checks for correct license headers of all files.
|
| 212 |
+
To run this locally use `make license`.
|
| 213 |
+
See the Licensing Information section above for details about the license header required.
|
| 214 |
+
|
| 215 |
+
5. `pytest`
|
| 216 |
+
Checks if the test scripts from the `test` folder run and produce desired outputs. It
|
| 217 |
+
is imperative that your changes don't break the existing tests. If your MR fails this
|
| 218 |
+
test, you will have to review your changes and fix the issues.
|
| 219 |
+
To run pytest locally you can simply run `pytest` inside the `test` folder.
|
| 220 |
+
|
| 221 |
+
While writing these tests, we encourage you to make use of the [`@import_or_fail`](https://github.com/NVIDIA/physicsnemo/blob/main/test/pytest_utils.py#L25)
|
| 222 |
+
decorator to appropriately skip your tests for developers and users not having your
|
| 223 |
+
test specific dependencies. This mechanism helps us provide a better developer and
|
| 224 |
+
user experience when working with the unit tests.
|
| 225 |
+
|
| 226 |
+
Some of the tests require test data to be run; otherwise, they will be skipped.
|
| 227 |
+
To get the data (available to NVIDIANs only), set the `TEST_DATA_DIR` environment variable
|
| 228 |
+
to a desired value and run `make get-data`. After that, pytest will use the same
|
| 229 |
+
variable to find the test data. Alternatively, you can pass it explicitly using
|
| 230 |
+
`pytest --nfs-data-dir=<path to test data>`.
|
| 231 |
+
|
| 232 |
+
6. `doctest`
|
| 233 |
+
Checks if the examples in the docstrings run and produce desired outputs.
|
| 234 |
+
It is highly recommended that you provide simple examples of your functions/classes
|
| 235 |
+
in the code's docstring itself.
|
| 236 |
+
Keep these examples simple and also add the expected outputs.
|
| 237 |
+
Refer [doctest](https://docs.python.org/3/library/doctest.html) for more information.
|
| 238 |
+
If your MR fails this test, check your changes and the docstrings.
|
| 239 |
+
To run doctest locally, you can simply run `pytest --doctest-modules` inside the
|
| 240 |
+
`physicsnemo` folder.
|
| 241 |
+
|
| 242 |
+
7. `coverage`
|
| 243 |
+
Checks if your code additions have sufficient coverage.
|
| 244 |
+
Refer [coverage](https://coverage.readthedocs.io/en/6.5.0/index.html#) for more details.
|
| 245 |
+
If your MR fails this test, this means that you have not added enough tests to the `test`
|
| 246 |
+
folder for your module/functions.
|
| 247 |
+
Add extensive test scripts to cover different
|
| 248 |
+
branches and lines of your additions.
|
| 249 |
+
Aim for more than 80% code coverage.
|
| 250 |
+
To test coverage locally, run the `get_coverage.sh` script from the `test` folder and
|
| 251 |
+
check the coverage of the module that you added/edited.
|
physics_mcp/source/FAQ.md
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Frequently Asked Questions about PhysicsNeMo
|
| 2 |
+
|
| 3 |
+
## Table of contents
|
| 4 |
+
|
| 5 |
+
- [What is the recommended hardware for training using PhysicsNeMo framework?](#what-is-the-recommended-hardware-for-training-using-physicsnemo-framework)
|
| 6 |
+
- [What model architectures are in PhysicsNeMo?](#what-model-architectures-are-in-physicsnemo)
|
| 7 |
+
- [What is the difference between PhysicsNeMo Core and Symbolic?](#what-is-the-difference-between-physicsnemo-core-and-symbolic)
|
| 8 |
+
- [What can I do if I don't see a PDE in PhysicsNeMo?](#what-can-i-do-if-i-dont-see-a-pde-in-physicsnemo)
|
| 9 |
+
- [What is the difference between the pip install and the container?](#what-is-the-difference-between-the-pip-install-and-the-container)
|
| 10 |
+
|
| 11 |
+
## What is the recommended hardware for training using PhysicsNeMo framework?
|
| 12 |
+
|
| 13 |
+
Please refer to the recommended hardware section:
|
| 14 |
+
[System Requirements](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#system-requirements)
|
| 15 |
+
|
| 16 |
+
## What model architectures are in PhysicsNeMo?
|
| 17 |
+
|
| 18 |
+
NVIDIA PhysicsNeMo is built on top of PyTorch and you can build and train any model
|
| 19 |
+
architecture you want in PhysicsNeMo. PhysicsNeMo however has a catalog of models that
|
| 20 |
+
have been packaged in a configurable form to make it easy to retrain with new data or certain
|
| 21 |
+
config parameters. Examples include GNNs like MeshGraphNet or Neural Operators like FNO.
|
| 22 |
+
PhysicsNeMo samples have more models that illustrate how a specific approach with a specific
|
| 23 |
+
model architecture can be applied to a specific problem.
|
| 24 |
+
These are reference starting points for users to get started.
|
| 25 |
+
|
| 26 |
+
You can find the list of built in model architectures
|
| 27 |
+
[here](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models) and
|
| 28 |
+
[here](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/models)
|
| 29 |
+
|
| 30 |
+
## What is the difference between PhysicsNeMo Core and Symbolic?
|
| 31 |
+
|
| 32 |
+
PhysicsNeMo core is the foundational module that provides the core algorithms, network
|
| 33 |
+
architectures and utilities that cover a broad spectrum of Physics-ML approaches.
|
| 34 |
+
PhysicsNeMo Symbolic provides pythonic APIs, algorithms and utilities to be used with
|
| 35 |
+
PhysicsNeMo core, to explicitly physics inform the model training. This includes symbolic
|
| 36 |
+
APIs for PDEs, domain sampling and PDE-based residuals. It also provides higher level
|
| 37 |
+
abstraction to compose a training loop from specification of the geometry, PDEs and
|
| 38 |
+
constraints like boundary conditions using simple symbolic APIs.
|
| 39 |
+
So if you are familiar with PyTorch and want to train a model from a dataset, you start
|
| 40 |
+
with PhysicsNeMo core and you import PhysicsNeMo symbolic to bring in explicit domain knowledge.
|
| 41 |
+
Please refer to the [DeepONet example](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/darcy_deeponet_physics)
|
| 42 |
+
that illustrates the concept.
|
| 43 |
+
If you are an engineer or domain expert accustomed to using numerical solvers, you can
|
| 44 |
+
use PhysicsNeMo Symbolic to define your problem at a higher level of abstraction. Please
|
| 45 |
+
refer to the [Lid Driven cavity](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/basics/lid_driven_cavity_flow.html)
|
| 46 |
+
that illustrates the concept.
|
| 47 |
+
|
| 48 |
+
## What can I do if I don't see a PDE in PhysicsNeMo?
|
| 49 |
+
|
| 50 |
+
PhysicsNeMo Symbolic provides a well documented
|
| 51 |
+
[example](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/foundational/1d_wave_equation.html#writing-custom-pdes-and-boundary-initial-conditions)
|
| 52 |
+
that walks you through how to define a custom PDE. Please see the source [here](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/eq/pdes)
|
| 53 |
+
to see the built-in PDE implementation as an additional reference for your own implementation.
|
| 54 |
+
|
| 55 |
+
## What is the difference between the pip install and the container?
|
| 56 |
+
|
| 57 |
+
There is no functional difference between the two. This is to simplify the ease of
|
| 58 |
+
installing and setting up the PhysicsNeMo environment. Please refer to the
|
| 59 |
+
[getting started guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html#physicsnemo-with-docker-image-recommended)
|
| 60 |
+
on how to install using Pip or using the container.
|
physics_mcp/source/LICENSE.txt
ADDED
|
@@ -0,0 +1,201 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Apache License
|
| 2 |
+
Version 2.0, January 2004
|
| 3 |
+
http://www.apache.org/licenses/
|
| 4 |
+
|
| 5 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 6 |
+
|
| 7 |
+
1. Definitions.
|
| 8 |
+
|
| 9 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 10 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 11 |
+
|
| 12 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 13 |
+
the copyright owner that is granting the License.
|
| 14 |
+
|
| 15 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 16 |
+
other entities that control, are controlled by, or are under common
|
| 17 |
+
control with that entity. For the purposes of this definition,
|
| 18 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 19 |
+
direction or management of such entity, whether by contract or
|
| 20 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 21 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 22 |
+
|
| 23 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 24 |
+
exercising permissions granted by this License.
|
| 25 |
+
|
| 26 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 27 |
+
including but not limited to software source code, documentation
|
| 28 |
+
source, and configuration files.
|
| 29 |
+
|
| 30 |
+
"Object" form shall mean any form resulting from mechanical
|
| 31 |
+
transformation or translation of a Source form, including but
|
| 32 |
+
not limited to compiled object code, generated documentation,
|
| 33 |
+
and conversions to other media types.
|
| 34 |
+
|
| 35 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 36 |
+
Object form, made available under the License, as indicated by a
|
| 37 |
+
copyright notice that is included in or attached to the work
|
| 38 |
+
(an example is provided in the Appendix below).
|
| 39 |
+
|
| 40 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 41 |
+
form, that is based on (or derived from) the Work and for which the
|
| 42 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 43 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 44 |
+
of this License, Derivative Works shall not include works that remain
|
| 45 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 46 |
+
the Work and Derivative Works thereof.
|
| 47 |
+
|
| 48 |
+
"Contribution" shall mean any work of authorship, including
|
| 49 |
+
the original version of the Work and any modifications or additions
|
| 50 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 51 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 52 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 53 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 54 |
+
means any form of electronic, verbal, or written communication sent
|
| 55 |
+
to the Licensor or its representatives, including but not limited to
|
| 56 |
+
communication on electronic mailing lists, source code control systems,
|
| 57 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 58 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 59 |
+
excluding communication that is conspicuously marked or otherwise
|
| 60 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 61 |
+
|
| 62 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 63 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 64 |
+
subsequently incorporated within the Work.
|
| 65 |
+
|
| 66 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 67 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 68 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 69 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 70 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 71 |
+
Work and such Derivative Works in Source or Object form.
|
| 72 |
+
|
| 73 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 74 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 75 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 76 |
+
(except as stated in this section) patent license to make, have made,
|
| 77 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 78 |
+
where such license applies only to those patent claims licensable
|
| 79 |
+
by such Contributor that are necessarily infringed by their
|
| 80 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 81 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 82 |
+
institute patent litigation against any entity (including a
|
| 83 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 84 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 85 |
+
or contributory patent infringement, then any patent licenses
|
| 86 |
+
granted to You under this License for that Work shall terminate
|
| 87 |
+
as of the date such litigation is filed.
|
| 88 |
+
|
| 89 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 90 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 91 |
+
modifications, and in Source or Object form, provided that You
|
| 92 |
+
meet the following conditions:
|
| 93 |
+
|
| 94 |
+
(a) You must give any other recipients of the Work or
|
| 95 |
+
Derivative Works a copy of this License; and
|
| 96 |
+
|
| 97 |
+
(b) You must cause any modified files to carry prominent notices
|
| 98 |
+
stating that You changed the files; and
|
| 99 |
+
|
| 100 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 101 |
+
that You distribute, all copyright, patent, trademark, and
|
| 102 |
+
attribution notices from the Source form of the Work,
|
| 103 |
+
excluding those notices that do not pertain to any part of
|
| 104 |
+
the Derivative Works; and
|
| 105 |
+
|
| 106 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 107 |
+
distribution, then any Derivative Works that You distribute must
|
| 108 |
+
include a readable copy of the attribution notices contained
|
| 109 |
+
within such NOTICE file, excluding those notices that do not
|
| 110 |
+
pertain to any part of the Derivative Works, in at least one
|
| 111 |
+
of the following places: within a NOTICE text file distributed
|
| 112 |
+
as part of the Derivative Works; within the Source form or
|
| 113 |
+
documentation, if provided along with the Derivative Works; or,
|
| 114 |
+
within a display generated by the Derivative Works, if and
|
| 115 |
+
wherever such third-party notices normally appear. The contents
|
| 116 |
+
of the NOTICE file are for informational purposes only and
|
| 117 |
+
do not modify the License. You may add Your own attribution
|
| 118 |
+
notices within Derivative Works that You distribute, alongside
|
| 119 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 120 |
+
that such additional attribution notices cannot be construed
|
| 121 |
+
as modifying the License.
|
| 122 |
+
|
| 123 |
+
You may add Your own copyright statement to Your modifications and
|
| 124 |
+
may provide additional or different license terms and conditions
|
| 125 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 126 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 127 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 128 |
+
the conditions stated in this License.
|
| 129 |
+
|
| 130 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 131 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 132 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 133 |
+
this License, without any additional terms or conditions.
|
| 134 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 135 |
+
the terms of any separate license agreement you may have executed
|
| 136 |
+
with Licensor regarding such Contributions.
|
| 137 |
+
|
| 138 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 139 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 140 |
+
except as required for reasonable and customary use in describing the
|
| 141 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 142 |
+
|
| 143 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 144 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 145 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 146 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 147 |
+
implied, including, without limitation, any warranties or conditions
|
| 148 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 149 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 150 |
+
appropriateness of using or redistributing the Work and assume any
|
| 151 |
+
risks associated with Your exercise of permissions under this License.
|
| 152 |
+
|
| 153 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 154 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 155 |
+
unless required by applicable law (such as deliberate and grossly
|
| 156 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 157 |
+
liable to You for damages, including any direct, indirect, special,
|
| 158 |
+
incidental, or consequential damages of any character arising as a
|
| 159 |
+
result of this License or out of the use or inability to use the
|
| 160 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 161 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 162 |
+
other commercial damages or losses), even if such Contributor
|
| 163 |
+
has been advised of the possibility of such damages.
|
| 164 |
+
|
| 165 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 166 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 167 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 168 |
+
or other liability obligations and/or rights consistent with this
|
| 169 |
+
License. However, in accepting such obligations, You may act only
|
| 170 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 171 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 172 |
+
defend, and hold each Contributor harmless for any liability
|
| 173 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 174 |
+
of your accepting any such warranty or additional liability.
|
| 175 |
+
|
| 176 |
+
END OF TERMS AND CONDITIONS
|
| 177 |
+
|
| 178 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 179 |
+
|
| 180 |
+
To apply the Apache License to your work, attach the following
|
| 181 |
+
boilerplate notice, with the fields enclosed by brackets "{}"
|
| 182 |
+
replaced with your own identifying information. (Don't include
|
| 183 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 184 |
+
comment syntax for the file format. We also recommend that a
|
| 185 |
+
file or class name and description of purpose be included on the
|
| 186 |
+
same "printed page" as the copyright notice for easier
|
| 187 |
+
identification within third-party archives.
|
| 188 |
+
|
| 189 |
+
Copyright 2022 NVIDIA Corporation
|
| 190 |
+
|
| 191 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 192 |
+
you may not use this file except in compliance with the License.
|
| 193 |
+
You may obtain a copy of the License at
|
| 194 |
+
|
| 195 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 196 |
+
|
| 197 |
+
Unless required by applicable law or agreed to in writing, software
|
| 198 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 199 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 200 |
+
See the License for the specific language governing permissions and
|
| 201 |
+
limitations under the License.
|
physics_mcp/source/README.md
ADDED
|
@@ -0,0 +1,472 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# NVIDIA PhysicsNeMo
|
| 2 |
+
|
| 3 |
+
<!-- markdownlint-disable -->
|
| 4 |
+
|
| 5 |
+
📝 NVIDIA Modulus has been renamed to NVIDIA PhysicsNeMo
|
| 6 |
+
|
| 7 |
+
[](https://www.repostatus.org/#active)
|
| 8 |
+
[](https://github.com/NVIDIA/physicsnemo/blob/master/LICENSE.txt)
|
| 9 |
+
[](https://github.com/psf/black)
|
| 10 |
+
<!-- markdownlint-enable -->
|
| 11 |
+
[**NVIDIA PhysicsNeMo**](#what-is-physicsnemo)
|
| 12 |
+
| [**Documentation**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html)
|
| 13 |
+
| [**Install Guide**](#installation)
|
| 14 |
+
| [**Getting Started**](#getting-started)
|
| 15 |
+
| [**Contributing Guidelines**](#contributing-to-physicsnemo)
|
| 16 |
+
| [**License**](#license)
|
| 17 |
+
|
| 18 |
+
## What is PhysicsNeMo?
|
| 19 |
+
|
| 20 |
+
NVIDIA PhysicsNeMo is an open-source deep-learning framework for building, training,
|
| 21 |
+
fine-tuning, and inferring Physics AI models using state-of-the-art SciML methods for
|
| 22 |
+
AI4Science and engineering.
|
| 23 |
+
|
| 24 |
+
PhysicsNeMo provides Python modules to compose scalable and optimized training and
|
| 25 |
+
inference pipelines to explore, develop, validate, and deploy AI models that combine
|
| 26 |
+
physics knowledge with data, enabling real-time predictions.
|
| 27 |
+
|
| 28 |
+
Whether you are exploring the use of neural operators, GNNs, or transformers, or are
|
| 29 |
+
interested in Physics-Informed Neural Networks or a hybrid approach in between, PhysicsNeMo
|
| 30 |
+
provides you with an optimized stack that will enable you to train your models at scale.
|
| 31 |
+
|
| 32 |
+
<!-- markdownlint-disable -->
|
| 33 |
+
<p align="center">
|
| 34 |
+
<img src=https://raw.githubusercontent.com/NVIDIA/physicsnemo/main/docs/img/value_prop/Knowledge_guided_models.gif alt="PhysicsNeMo"/>
|
| 35 |
+
</p>
|
| 36 |
+
<!-- markdownlint-enable -->
|
| 37 |
+
|
| 38 |
+
<!-- toc -->
|
| 39 |
+
|
| 40 |
+
- [More About PhysicsNeMo](#more-about-physicsnemo)
|
| 41 |
+
- [Scalable GPU-Optimized Training Library](#scalable-gpu-optimized-training-library)
|
| 42 |
+
- [A Suite of Physics-Informed ML Models](#a-suite-of-physics-informed-ml-models)
|
| 43 |
+
- [Seamless PyTorch Integration](#seamless-pytorch-integration)
|
| 44 |
+
- [Easy Customization and Extension](#easy-customization-and-extension)
|
| 45 |
+
- [AI4Science Library](#ai4science-library)
|
| 46 |
+
- [Domain-Specific Packages](#domain-specific-packages)
|
| 47 |
+
- [Who is Using and Contributing to PhysicsNeMo](#who-is-using-and-contributing-to-physicsnemo)
|
| 48 |
+
- [Why Use PhysicsNeMo](#why-are-they-using-physicsnemo)
|
| 49 |
+
- [Getting Started](#getting-started)
|
| 50 |
+
- [Resources](#resources)
|
| 51 |
+
- [Installation](#installation)
|
| 52 |
+
- [Contributing](#contributing-to-physicsnemo)
|
| 53 |
+
- [Communication](#communication)
|
| 54 |
+
- [License](#license)
|
| 55 |
+
|
| 56 |
+
<!-- tocstop -->
|
| 57 |
+
|
| 58 |
+
## More About PhysicsNeMo
|
| 59 |
+
|
| 60 |
+
At a granular level, PhysicsNeMo is developed as modular functionality and therefore
|
| 61 |
+
provides built-in composable modules that are packaged into a few key components:
|
| 62 |
+
|
| 63 |
+
<!-- markdownlint-disable -->
|
| 64 |
+
Component | Description |
|
| 65 |
+
---- | --- |
|
| 66 |
+
[**physicsnemo.models**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html) | A collection of optimized, customizable, and easy-to-use families of model architectures such as Neural Operators, Graph Neural Networks, Diffusion models, Transformer models, and many more|
|
| 67 |
+
[**physicsnemo.datapipes**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.datapipes.html) | Optimized and scalable built-in data pipelines fine-tuned to handle engineering and scientific data structures like point clouds, meshes, etc.|
|
| 68 |
+
[**physicsnemo.distributed**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html) | A distributed computing sub-module built on top of `torch.distributed` to enable parallel training with just a few steps|
|
| 69 |
+
[**physicsnemo.curator**](https://github.com/NVIDIA/physicsnemo-curator) | A sub-module to streamline and accelerate the process of data curation for engineering datasets|
|
| 70 |
+
[**physicsnemo.sym.geometry**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/features/csg_and_tessellated_module.html) | A sub-module to handle geometry for DL training using Constructive Solid Geometry modeling and CAD files in STL format|
|
| 71 |
+
[**physicsnemo.sym.eq**](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/features/nodes.html) | A sub-module to use PDEs in your DL training with several implementations of commonly observed equations and easy ways for customization|
|
| 72 |
+
<!-- markdownlint-enable -->
|
| 73 |
+
|
| 74 |
+
For a complete list, refer to the PhysicsNeMo API documentation for
|
| 75 |
+
[PhysicsNeMo](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html).
|
| 76 |
+
|
| 77 |
+
## AI4Science Library
|
| 78 |
+
|
| 79 |
+
Usually, PhysicsNeMo is used either as:
|
| 80 |
+
|
| 81 |
+
- A complementary tool to PyTorch when exploring AI for SciML and AI4Science applications.
|
| 82 |
+
- A deep learning research platform that provides scale and optimal performance on
|
| 83 |
+
NVIDIA GPUs.
|
| 84 |
+
|
| 85 |
+
### Domain-Specific Packages
|
| 86 |
+
|
| 87 |
+
The following are packages dedicated to domain experts of specific communities, catering
|
| 88 |
+
to their unique exploration needs:
|
| 89 |
+
|
| 90 |
+
- [PhysicsNeMo CFD](https://github.com/NVIDIA/physicsnemo-cfd): Inference sub-module of PhysicsNeMo
|
| 91 |
+
to enable CFD domain experts to explore, experiment, and validate using pretrained
|
| 92 |
+
AI models for CFD use cases.
|
| 93 |
+
- [PhysicsNeMo Curator](https://github.com/NVIDIA/physicsnemo-curator): Sub-module
|
| 94 |
+
of PhysicsNeMo to streamline and accelerate the process of data curation for engineering
|
| 95 |
+
datasets.
|
| 96 |
+
- [Earth-2 Studio](https://github.com/NVIDIA/earth2studio): Inference sub-module of PhysicsNeMo
|
| 97 |
+
to enable climate researchers and scientists to explore and experiment with pretrained
|
| 98 |
+
AI models for weather and climate.
|
| 99 |
+
|
| 100 |
+
### Scalable GPU-Optimized Training Library
|
| 101 |
+
|
| 102 |
+
PhysicsNeMo provides a highly optimized and scalable training library for maximizing the
|
| 103 |
+
power of NVIDIA GPUs.
|
| 104 |
+
[Distributed computing](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html)
|
| 105 |
+
utilities allow for efficient scaling from a single GPU to multi-node GPU clusters with
|
| 106 |
+
a few lines of code, ensuring that large-scale
|
| 107 |
+
physics-informed machine learning (ML) models can be trained quickly and effectively.
|
| 108 |
+
The framework includes support for advanced
|
| 109 |
+
[optimization utilities](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.utils.html#module-physicsnemo.utils.capture),
|
| 110 |
+
[tailor-made datapipes](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.datapipes.html),
|
| 111 |
+
and [validation utilities](https://github.com/NVIDIA/physicsnemo-sym/tree/main/physicsnemo/sym/eq)
|
| 112 |
+
to enhance end-to-end training speed.
|
| 113 |
+
|
| 114 |
+
### A Suite of Physics-Informed ML Models
|
| 115 |
+
|
| 116 |
+
PhysicsNeMo offers a library of state-of-the-art models specifically designed
|
| 117 |
+
for Physics-ML applications. Users can build any model architecture by using the underlying
|
| 118 |
+
PyTorch layers and combining them with curated PhysicsNeMo layers.
|
| 119 |
+
|
| 120 |
+
The [Model Zoo](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html#model-zoo)
|
| 121 |
+
includes optimized implementations of families of model architectures such as
|
| 122 |
+
Neural Operators:
|
| 123 |
+
|
| 124 |
+
- [Fourier Neural Operators (FNOs)](physicsnemo/models/fno)
|
| 125 |
+
- [DeepONet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/neural_operators/deeponet.html)
|
| 126 |
+
- [DoMINO](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/external_aerodynamics/domino/readme.html)
|
| 127 |
+
- [Graph Neural Networks (GNNs)](physicsnemo/models/gnn_layers)
|
| 128 |
+
- [MeshGraphNet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/vortex_shedding_mgn/readme.html)
|
| 129 |
+
- [MeshGraphNet for Lagrangian](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/lagrangian_mgn/readme.html)
|
| 130 |
+
- [XAeroNet](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/cfd/external_aerodynamics/xaeronet/readme.html)
|
| 131 |
+
- [Diffusion Models](physicsnemo/models/diffusion)
|
| 132 |
+
- [Correction Diffusion Model](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/generative/corrdiff/readme.html)
|
| 133 |
+
- [DDPM](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/generative/diffusion/readme.html)
|
| 134 |
+
- [PhysicsNeMo GraphCast](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/examples/weather/graphcast/readme.html)
|
| 135 |
+
- [Transolver](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/darcy_transolver)
|
| 136 |
+
- [RNNs](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models)
|
| 137 |
+
- [SwinVRNN](https://github.com/NVIDIA/physicsnemo/tree/main/physicsnemo/models/swinvrnn)
|
| 138 |
+
- [Physics-Informed Neural Networks (PINNs)](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-sym/user_guide/foundational/1d_wave_equation.html)
|
| 139 |
+
|
| 140 |
+
And many others.
|
| 141 |
+
|
| 142 |
+
These models are optimized for various physics domains, such as computational fluid
|
| 143 |
+
dynamics, structural mechanics, and electromagnetics. Users can download, customize, and
|
| 144 |
+
build upon these models to suit their specific needs, significantly reducing the time
|
| 145 |
+
required to develop high-fidelity simulations.
|
| 146 |
+
|
| 147 |
+
### Seamless PyTorch Integration
|
| 148 |
+
|
| 149 |
+
PhysicsNeMo is built on top of PyTorch, providing a familiar and user-friendly experience
|
| 150 |
+
for those already proficient with PyTorch.
|
| 151 |
+
This includes a simple Python interface and modular design, making it easy to use
|
| 152 |
+
PhysicsNeMo with existing PyTorch workflows.
|
| 153 |
+
Users can leverage the extensive PyTorch ecosystem, including its libraries and tools,
|
| 154 |
+
while benefiting from PhysicsNeMo's specialized capabilities for physics-ML. This seamless
|
| 155 |
+
integration ensures users can quickly adopt PhysicsNeMo without a steep learning curve.
|
| 156 |
+
|
| 157 |
+
For more information, refer to [Converting PyTorch Models to PhysicsNeMo Models](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.models.html#converting-pytorch-models-to-physicsnemo-models).
|
| 158 |
+
|
| 159 |
+
### Easy Customization and Extension
|
| 160 |
+
|
| 161 |
+
PhysicsNeMo is designed to be highly extensible, allowing users to add new functionality
|
| 162 |
+
with minimal effort. The framework provides Pythonic APIs for
|
| 163 |
+
defining new physics models, geometries, and constraints, making it easy to extend its
|
| 164 |
+
capabilities to new use cases.
|
| 165 |
+
The adaptability of PhysicsNeMo is further enhanced by key features such as
|
| 166 |
+
[ONNX support](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.deploy.html)
|
| 167 |
+
for flexible model deployment,
|
| 168 |
+
robust [logging utilities](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.launch.logging.html)
|
| 169 |
+
for streamlined error handling,
|
| 170 |
+
and efficient
|
| 171 |
+
[checkpointing](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.launch.utils.html#module-physicsnemo.launch.utils.checkpoint)
|
| 172 |
+
to simplify model loading and saving.
|
| 173 |
+
|
| 174 |
+
This extensibility ensures that PhysicsNeMo can adapt to the evolving needs of researchers
|
| 175 |
+
and engineers, facilitating the development of innovative solutions in the field of physics-ML.
|
| 176 |
+
|
| 177 |
+
Detailed information on features and capabilities can be found in the [PhysicsNeMo documentation](https://docs.nvidia.com/physicsnemo/index.html#core).
|
| 178 |
+
|
| 179 |
+
[Reference samples](examples/README.md) cover a broad spectrum of physics-constrained
|
| 180 |
+
and data-driven
|
| 181 |
+
workflows to suit the diversity of use cases in the science and engineering disciplines.
|
| 182 |
+
|
| 183 |
+
> [!TIP]
|
| 184 |
+
> Have questions about how PhysicsNeMo can assist you? Try our [Experimental] chatbot,
|
| 185 |
+
> [PhysicsNeMo Guide](https://chatgpt.com/g/g-PXrBv20SC-modulus-guide), for answers.
|
| 186 |
+
|
| 187 |
+
### Hello World
|
| 188 |
+
|
| 189 |
+
You can start using PhysicsNeMo in your PyTorch code as simply as shown here:
|
| 190 |
+
|
| 191 |
+
```python
|
| 192 |
+
>>> import torch
|
| 193 |
+
>>> from physicsnemo.models.mlp.fully_connected import FullyConnected
|
| 194 |
+
>>> model = FullyConnected(in_features=32, out_features=64)
|
| 195 |
+
>>> input = torch.randn(128, 32)
|
| 196 |
+
>>> output = model(input)
|
| 197 |
+
>>> output.shape
|
| 198 |
+
torch.Size([128, 64])
|
| 199 |
+
```
|
| 200 |
+
|
| 201 |
+
To use the distributed module, you can do the following (example for
|
| 202 |
+
distributed data parallel training; for a more in-depth tutorial, refer to
|
| 203 |
+
[PhysicsNeMo Distributed](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/api/physicsnemo.distributed.html#)):
|
| 204 |
+
|
| 205 |
+
```python
|
| 206 |
+
import torch
|
| 207 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 208 |
+
from physicsnemo.distributed import DistributedManager
|
| 209 |
+
from physicsnemo.models.mlp.fully_connected import FullyConnected
|
| 210 |
+
|
| 211 |
+
def main():
|
| 212 |
+
DistributedManager.initialize()
|
| 213 |
+
dist = DistributedManager()
|
| 214 |
+
|
| 215 |
+
arch = FullyConnected(in_features=32, out_features=64).to(dist.device)
|
| 216 |
+
|
| 217 |
+
if dist.distributed:
|
| 218 |
+
ddps = torch.cuda.Stream()
|
| 219 |
+
with torch.cuda.stream(ddps):
|
| 220 |
+
arch = DistributedDataParallel(
|
| 221 |
+
arch,
|
| 222 |
+
device_ids=[dist.local_rank],
|
| 223 |
+
output_device=dist.device,
|
| 224 |
+
broadcast_buffers=dist.broadcast_buffers,
|
| 225 |
+
find_unused_parameters=dist.find_unused_parameters,
|
| 226 |
+
)
|
| 227 |
+
torch.cuda.current_stream().wait_stream(ddps)
|
| 228 |
+
|
| 229 |
+
# Set up the optimizer
|
| 230 |
+
optimizer = torch.optim.Adam(
|
| 231 |
+
arch.parameters(),
|
| 232 |
+
lr=0.001,
|
| 233 |
+
)
|
| 234 |
+
|
| 235 |
+
def training_step(invar, target):
|
| 236 |
+
pred = arch(invar)
|
| 237 |
+
loss = torch.sum(torch.pow(pred - target, 2))
|
| 238 |
+
loss.backward()
|
| 239 |
+
optimizer.step()
|
| 240 |
+
return loss
|
| 241 |
+
|
| 242 |
+
# Sample training loop
|
| 243 |
+
for i in range(20):
|
| 244 |
+
# Random inputs and targets for simplicity
|
| 245 |
+
input = torch.randn(128, 32, device=dist.device)
|
| 246 |
+
target = torch.randn(128, 64, device=dist.device)
|
| 247 |
+
|
| 248 |
+
# Training step
|
| 249 |
+
loss = training_step(input, target)
|
| 250 |
+
|
| 251 |
+
if __name__ == "__main__":
|
| 252 |
+
main()
|
| 253 |
+
```
|
| 254 |
+
|
| 255 |
+
To use the PDE module, you can do the following:
|
| 256 |
+
|
| 257 |
+
```python
|
| 258 |
+
>>> from physicsnemo.sym.eq.pdes.navier_stokes import NavierStokes
|
| 259 |
+
>>> ns = NavierStokes(nu=0.01, rho=1, dim=2)
|
| 260 |
+
>>> ns.pprint()
|
| 261 |
+
continuity: u__x + v__y
|
| 262 |
+
momentum_x: u*u__x + v*u__y + p__x + u__t - 0.01*u__x__x - 0.01*u__y__y
|
| 263 |
+
momentum_y: u*v__x + v*v__y + p__y + v__t - 0.01*v__x__x - 0.01*v__y__y
|
| 264 |
+
```
|
| 265 |
+
|
| 266 |
+
## Who is Using and Contributing to PhysicsNeMo
|
| 267 |
+
|
| 268 |
+
PhysicsNeMo is an open-source project and gets contributions from researchers in
|
| 269 |
+
the SciML and AI4Science fields. While the PhysicsNeMo team works on optimizing the
|
| 270 |
+
underlying software stack, the community collaborates and contributes model architectures,
|
| 271 |
+
datasets, and reference applications so we can innovate in the pursuit of
|
| 272 |
+
developing generalizable model architectures and algorithms.
|
| 273 |
+
|
| 274 |
+
Some recent examples of community contributors are the [HP Labs 3D Printing team](https://developer.nvidia.com/blog/spotlight-hp-3d-printing-and-nvidia-physicsnemo-collaborate-on-open-source-manufacturing-digital-twin/),
|
| 275 |
+
[Stanford Cardiovascular research team](https://developer.nvidia.com/blog/enabling-greater-patient-specific-cardiovascular-care-with-ai-surrogates/),
|
| 276 |
+
[UIUC team](https://github.com/NVIDIA/physicsnemo/tree/main/examples/cfd/mhd_pino),
|
| 277 |
+
[CMU team](https://github.com/NVIDIA/physicsnemo/tree/main/examples/generative/diffusion),
|
| 278 |
+
etc.
|
| 279 |
+
|
| 280 |
+
Recent examples of research teams using PhysicsNeMo are the
|
| 281 |
+
[ORNL team](https://arxiv.org/abs/2404.05768),
|
| 282 |
+
[TU Munich CFD team](https://www.nvidia.com/en-us/on-demand/session/gtc24-s62237/), etc.
|
| 283 |
+
|
| 284 |
+
Please navigate to this page for a complete list of research work leveraging PhysicsNeMo.
|
| 285 |
+
For a list of enterprises using PhysicsNeMo, refer to the [PhysicsNeMo Webpage](https://developer.nvidia.com/physicsnemo).
|
| 286 |
+
|
| 287 |
+
Using PhysicsNeMo and interested in showcasing your work on
|
| 288 |
+
[NVIDIA Blogs](https://developer.nvidia.com/blog/category/simulation-modeling-design/)?
|
| 289 |
+
Fill out this [proposal form](https://forms.gle/XsBdWp3ji67yZAUF7) and we will get back
|
| 290 |
+
to you!
|
| 291 |
+
|
| 292 |
+
## Why Are They Using PhysicsNeMo
|
| 293 |
+
|
| 294 |
+
Here are some of the key benefits of PhysicsNeMo for SciML model development:
|
| 295 |
+
|
| 296 |
+
<!-- markdownlint-disable -->
|
| 297 |
+
<img src="docs/img/value_prop/benchmarking.svg" width="100"> | <img src="docs/img/value_prop/recipe.svg" width="100"> | <img src="docs/img/value_prop/performance.svg" width="100">
|
| 298 |
+
---|---|---|
|
| 299 |
+
|SciML Benchmarking and Validation|Ease of Using Generalized SciML Recipes with Heterogeneous Datasets |Out-of-the-Box Performance and Scalability
|
| 300 |
+
|PhysicsNeMo enables researchers to benchmark their AI models against proven architectures for standard benchmark problems with detailed domain-specific validation criteria.|PhysicsNeMo enables researchers to pick from state-of-the-art SciML architectures and use built-in data pipelines for their use case.| PhysicsNeMo provides out-of-the-box performant training pipelines, including optimized ETL pipelines for heterogeneous engineering and scientific datasets and out-of-the-box scaling across multi-GPU and multi-node GPUs.
|
| 301 |
+
<!-- markdownlint-enable -->
|
| 302 |
+
|
| 303 |
+
See what your peer SciML researchers are saying about PhysicsNeMo (coming soon).
|
| 304 |
+
|
| 305 |
+
## Getting Started
|
| 306 |
+
|
| 307 |
+
The following resources will help you learn how to use PhysicsNeMo. The best
|
| 308 |
+
way is to start with a reference sample and then update it for your own use case.
|
| 309 |
+
|
| 310 |
+
- [Using PhysicsNeMo with your PyTorch model](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/tutorials/simple_training_example.html#using-custom-models-in-physicsnemo)
|
| 311 |
+
- [Using PhysicsNeMo built-in models](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/tutorials/simple_training_example.html#using-built-in-models)
|
| 312 |
+
- [Getting Started Guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html)
|
| 313 |
+
- [Reference Samples](https://github.com/NVIDIA/physicsnemo/blob/main/examples/README.md)
|
| 314 |
+
- [User Guide Documentation](https://docs.nvidia.com/deeplearning/physicsnemo/physicsnemo-core/index.html)
|
| 315 |
+
|
| 316 |
+
## Resources
|
| 317 |
+
|
| 318 |
+
- [Getting Started Webinar](https://www.nvidia.com/en-us/on-demand/session/gtc24-dlit61460/?playlistId=playList-bd07f4dc-1397-4783-a959-65cec79aa985)
|
| 319 |
+
- [AI4Science PhysicsNeMo Bootcamp](https://github.com/openhackathons-org/End-to-End-AI-for-Science)
|
| 320 |
+
- [PhysicsNeMo Pretrained Models](https://catalog.ngc.nvidia.com/models?filters=&orderBy=scoreDESC&query=PhysicsNeMo&page=&pageSize=)
|
| 321 |
+
- [PhysicsNeMo Datasets and Supplementary Materials](https://catalog.ngc.nvidia.com/resources?filters=&orderBy=scoreDESC&query=PhysicsNeMo&page=&pageSize=)
|
| 322 |
+
- [Self-Paced PhysicsNeMo DLI Training](https://learn.nvidia.com/courses/course-detail?course_id=course-v1:DLI+S-OV-04+V1)
|
| 323 |
+
- [Deep Learning for Science and Engineering Lecture Series with PhysicsNeMo](https://www.nvidia.com/en-us/on-demand/deep-learning-for-science-and-engineering/)
|
| 324 |
+
- [PhysicsNeMo: Purpose and Usage](https://www.nvidia.com/en-us/on-demand/session/dliteachingkit-setk5002/)
|
| 325 |
+
- [Video Tutorials](https://www.nvidia.com/en-us/on-demand/search/?facet.mimetype[]=event%20session&layout=list&page=1&q=physicsnemo&sort=relevance&sortDir=desc)
|
| 326 |
+
|
| 327 |
+
## Installation
|
| 328 |
+
|
| 329 |
+
The following instructions help you install the base PhysicsNeMo modules to get started.
|
| 330 |
+
There are additional optional dependencies for specific models that are listed under
|
| 331 |
+
[optional dependencies](#optional-dependencies).
|
| 332 |
+
The training recipes are not packaged into the pip wheels or the container to keep the
|
| 333 |
+
footprint low. We recommend users clone the appropriate training recipes and use them
|
| 334 |
+
as a starting point. These training recipes may require additional example-specific dependencies,
|
| 335 |
+
as indicated through their associated `requirements.txt` file.
|
| 336 |
+
|
| 337 |
+
### PyPI
|
| 338 |
+
|
| 339 |
+
The recommended method for installing the latest version of PhysicsNeMo is using PyPI:
|
| 340 |
+
|
| 341 |
+
```Bash
|
| 342 |
+
pip install nvidia-physicsnemo
|
| 343 |
+
```
|
| 344 |
+
|
| 345 |
+
The installation can be verified by running the [Hello World](#hello-world) example.
|
| 346 |
+
|
| 347 |
+
#### Optional Dependencies
|
| 348 |
+
|
| 349 |
+
PhysicsNeMo has many optional dependencies that are used in specific components.
|
| 350 |
+
When using pip, all dependencies used in PhysicsNeMo can be installed with
|
| 351 |
+
`pip install nvidia-physicsnemo[all]`. If you are developing PhysicsNeMo, developer dependencies
|
| 352 |
+
can be installed using `pip install nvidia-physicsnemo[dev]`. Otherwise, additional dependencies
|
| 353 |
+
can be installed on a case-by-case basis. Detailed information on installing the
|
| 354 |
+
optional dependencies can be found in the
|
| 355 |
+
[Getting Started Guide](https://docs.nvidia.com/deeplearning/physicsnemo/getting-started/index.html).
|
| 356 |
+
|
| 357 |
+
### NVCR Container
|
| 358 |
+
|
| 359 |
+
The recommended PhysicsNeMo Docker image can be pulled from the
|
| 360 |
+
[NVIDIA Container Registry](https://catalog.ngc.nvidia.com/orgs/nvidia/teams/physicsnemo/containers/physicsnemo)
|
| 361 |
+
(refer to the NGC registry for the latest tag):
|
| 362 |
+
|
| 363 |
+
```Bash
|
| 364 |
+
docker pull nvcr.io/nvidia/physicsnemo/physicsnemo:25.06
|
| 365 |
+
```
|
| 366 |
+
|
| 367 |
+
Inside the container, you can clone the PhysicsNeMo git repositories and get
|
| 368 |
+
started with the examples. The command below shows the instructions to launch
|
| 369 |
+
the PhysicsNeMo container and run examples from this repo:
|
| 370 |
+
|
| 371 |
+
```bash
|
| 372 |
+
docker run --shm-size=1g --ulimit memlock=-1 --ulimit stack=67108864 --runtime nvidia \
|
| 373 |
+
--rm -it nvcr.io/nvidia/physicsnemo/physicsnemo:25.06 bash
|
| 374 |
+
git clone https://github.com/NVIDIA/physicsnemo.git
|
| 375 |
+
cd physicsnemo/examples/cfd/darcy_fno/
|
| 376 |
+
pip install warp-lang # install NVIDIA Warp to run the Darcy example
|
| 377 |
+
python train_fno_darcy.py
|
| 378 |
+
```
|
| 379 |
+
|
| 380 |
+
## From Source
|
| 381 |
+
|
| 382 |
+
### Package
|
| 383 |
+
|
| 384 |
+
For a local build of the PhysicsNeMo Python package from source, use:
|
| 385 |
+
|
| 386 |
+
```Bash
|
| 387 |
+
git clone git@github.com:NVIDIA/physicsnemo.git && cd physicsnemo
|
| 388 |
+
|
| 389 |
+
pip install --upgrade pip
|
| 390 |
+
pip install .
|
| 391 |
+
```
|
| 392 |
+
|
| 393 |
+
### Source Container
|
| 394 |
+
|
| 395 |
+
To build the PhysicsNeMo Docker image:
|
| 396 |
+
|
| 397 |
+
```bash
|
| 398 |
+
docker build -t physicsnemo:deploy \
|
| 399 |
+
--build-arg TARGETPLATFORM=linux/amd64 --target deploy -f Dockerfile .
|
| 400 |
+
```
|
| 401 |
+
|
| 402 |
+
Alternatively, you can run `make container-deploy`.
|
| 403 |
+
|
| 404 |
+
To build the CI image:
|
| 405 |
+
|
| 406 |
+
```bash
|
| 407 |
+
docker build -t physicsnemo:ci \
|
| 408 |
+
--build-arg TARGETPLATFORM=linux/amd64 --target ci -f Dockerfile .
|
| 409 |
+
```
|
| 410 |
+
|
| 411 |
+
Alternatively, you can run `make container-ci`.
|
| 412 |
+
|
| 413 |
+
Currently, only `linux/amd64` and `linux/arm64` platforms are supported. If using
|
| 414 |
+
`linux/arm64`, some dependencies like `warp-lang` might not install correctly.
|
| 415 |
+
|
| 416 |
+
## PhysicsNeMo Migration Guide
|
| 417 |
+
|
| 418 |
+
NVIDIA Modulus has been renamed to NVIDIA PhysicsNeMo. For migration:
|
| 419 |
+
|
| 420 |
+
- Use `pip install nvidia-physicsnemo` rather than `pip install nvidia-modulus`
|
| 421 |
+
for PyPI wheels.
|
| 422 |
+
- Use `nvcr.io/nvidia/physicsnemo/physicsnemo:<tag>` rather than
|
| 423 |
+
`nvcr.io/nvidia/modulus/modulus:<tag>` for Docker containers.
|
| 424 |
+
- Replace `nvidia-modulus` with `nvidia-physicsnemo` in your pip requirements
|
| 425 |
+
files (`requirements.txt`, `setup.py`, `setup.cfg`, `pyproject.toml`, etc.).
|
| 426 |
+
- In your code, change the import statements from `import modulus` to
|
| 427 |
+
`import physicsnemo`.
|
| 428 |
+
|
| 429 |
+
The old PyPI registry and the NGC container registry will be deprecated soon
|
| 430 |
+
and will not receive any bug fixes/updates. The old checkpoints will remain
|
| 431 |
+
compatible with these updates.
|
| 432 |
+
|
| 433 |
+
More details to follow soon.
|
| 434 |
+
|
| 435 |
+
## DGL to PyTorch Geometric Migration Guide
|
| 436 |
+
|
| 437 |
+
PhysicsNeMo supports a wide range of Graph Neural Networks (GNNs),
|
| 438 |
+
including MeshGraphNet and others.
|
| 439 |
+
Currently, PhysicsNeMo uses the DGL library as its GNN backend,
|
| 440 |
+
with plans to completely transition to PyTorch Geometric (PyG) in a future release.
|
| 441 |
+
For more details, please refer to the [DGL-to-PyG migration guide](https://github.com/NVIDIA/physicsnemo/blob/main/examples/dgl_to_pyg_migration.md).
|
| 442 |
+
|
| 443 |
+
## Contributing to PhysicsNeMo
|
| 444 |
+
|
| 445 |
+
PhysicsNeMo is an open-source collaboration, and its success is rooted in community
|
| 446 |
+
contributions to further the field of Physics-ML. Thank you for contributing to the
|
| 447 |
+
project so others can build on top of your contributions.
|
| 448 |
+
|
| 449 |
+
For guidance on contributing to PhysicsNeMo, please refer to the
|
| 450 |
+
[contributing guidelines](CONTRIBUTING.md).
|
| 451 |
+
|
| 452 |
+
## Cite PhysicsNeMo
|
| 453 |
+
|
| 454 |
+
If PhysicsNeMo helped your research and you would like to cite it, please refer to the [guidelines](https://github.com/NVIDIA/physicsnemo/blob/main/CITATION.cff).
|
| 455 |
+
|
| 456 |
+
## Communication
|
| 457 |
+
|
| 458 |
+
- GitHub Discussions: Discuss new architectures, implementations, Physics-ML research, etc.
|
| 459 |
+
- GitHub Issues: Bug reports, feature requests, install issues, etc.
|
| 460 |
+
- PhysicsNeMo Forum: The [PhysicsNeMo Forum](https://forums.developer.nvidia.com/t/welcome-to-the-physicsnemo-ml-model-framework-forum/178556)
|
| 461 |
+
hosts an audience of new to moderate-level users and developers for general chat, online
|
| 462 |
+
discussions, collaboration, etc.
|
| 463 |
+
|
| 464 |
+
## Feedback
|
| 465 |
+
|
| 466 |
+
Want to suggest some improvements to PhysicsNeMo? Use our [feedback form](https://docs.google.com/forms/d/e/1FAIpQLSfX4zZ0Lp7MMxzi3xqvzX4IQDdWbkNh5H_a_clzIhclE2oSBQ/viewform?usp=sf_link).
|
| 467 |
+
|
| 468 |
+
## License
|
| 469 |
+
|
| 470 |
+
PhysicsNeMo is provided under the Apache License 2.0. Please see [LICENSE.txt](./LICENSE.txt)
|
| 471 |
+
for the full license text. Enterprise SLA, support, and preview access are available
|
| 472 |
+
under NVAIE.
|
physics_mcp/source/SECURITY.md
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Security
|
| 2 |
+
|
| 3 |
+
NVIDIA is dedicated to the security and trust of our software products and
|
| 4 |
+
services, including all source code repositories managed through our organization.
|
| 5 |
+
|
| 6 |
+
If you need to report a security issue, please use the appropriate contact points
|
| 7 |
+
outlined below. **Please do not report security vulnerabilities through GitHub/GitLab.**
|
| 8 |
+
|
| 9 |
+
## Reporting Potential Security Vulnerability in an NVIDIA Product
|
| 10 |
+
|
| 11 |
+
To report a potential security vulnerability in any NVIDIA product:
|
| 12 |
+
|
| 13 |
+
- Web: [Security Vulnerability Submission Form](https://www.nvidia.com/object/submit-security-vulnerability.html)
|
| 14 |
+
- E-Mail: `psirt@nvidia.com`
|
| 15 |
+
- We encourage you to use the following PGP key for secure email communication:
|
| 16 |
+
[NVIDIA public PGP Key for communication](https://www.nvidia.com/en-us/security/pgp-key)
|
| 17 |
+
- Please include the following information:
|
| 18 |
+
- Product/Driver name and version/branch that contains the vulnerability
|
| 19 |
+
- Type of vulnerability (code execution, denial of service, buffer overflow, etc.)
|
| 20 |
+
- Instructions to reproduce the vulnerability
|
| 21 |
+
- Proof-of-concept or exploit code
|
| 22 |
+
- Potential impact of the vulnerability, including how an attacker could
|
| 23 |
+
exploit the vulnerability
|
| 24 |
+
|
| 25 |
+
While NVIDIA currently does not have a bug bounty program, we do offer
|
| 26 |
+
acknowledgement when an externally reported security issue is addressed under our
|
| 27 |
+
coordinated vulnerability disclosure policy. Please visit our
|
| 28 |
+
[Product Security Incident Response Team (PSIRT)](https://www.nvidia.com/en-us/security/psirt-policies/)
|
| 29 |
+
policies page for more information.
|
| 30 |
+
|
| 31 |
+
## NVIDIA Product Security
|
| 32 |
+
|
| 33 |
+
For all security-related concerns, please visit NVIDIA's Product Security portal
|
| 34 |
+
at `https://www.nvidia.com/en-us/security`
|
physics_mcp/source/__init__.py
ADDED
|
@@ -0,0 +1,4 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# -*- coding: utf-8 -*-
|
| 2 |
+
"""
|
| 3 |
+
physicsnemo Project Package Initialization File
|
| 4 |
+
"""
|
physics_mcp/source/greptile.json
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"comment": "",
|
| 3 |
+
"fixWithAI": false,
|
| 4 |
+
"commentTypes": [
|
| 5 |
+
"logic",
|
| 6 |
+
"syntax",
|
| 7 |
+
"style"
|
| 8 |
+
],
|
| 9 |
+
"instructions": "",
|
| 10 |
+
"excludeAuthors": [
|
| 11 |
+
"dependabot[bot]",
|
| 12 |
+
"renovate[bot]"
|
| 13 |
+
],
|
| 14 |
+
"ignorePatterns": "greptile.json\n",
|
| 15 |
+
"summarySection": {
|
| 16 |
+
"included": true,
|
| 17 |
+
"collapsible": false,
|
| 18 |
+
"defaultOpen": false
|
| 19 |
+
},
|
| 20 |
+
"triggerOnUpdates": false,
|
| 21 |
+
"updateSummaryOnly": false,
|
| 22 |
+
"issuesTableSection": {
|
| 23 |
+
"included": true,
|
| 24 |
+
"collapsible": false,
|
| 25 |
+
"defaultOpen": false
|
| 26 |
+
},
|
| 27 |
+
"confidenceScoreSection": {
|
| 28 |
+
"included": false,
|
| 29 |
+
"collapsible": false,
|
| 30 |
+
"defaultOpen": false
|
| 31 |
+
},
|
| 32 |
+
"sequenceDiagramSection": {
|
| 33 |
+
"included": false,
|
| 34 |
+
"collapsible": false,
|
| 35 |
+
"defaultOpen": false
|
| 36 |
+
},
|
| 37 |
+
"shouldUpdateDescription": false,
|
| 38 |
+
"customContext": {
|
| 39 |
+
"other": [
|
| 40 |
+
{
|
| 41 |
+
"scope": [],
|
| 42 |
+
"content": ""
|
| 43 |
+
}
|
| 44 |
+
],
|
| 45 |
+
"rules": [
|
| 46 |
+
{
|
| 47 |
+
"scope": [],
|
| 48 |
+
"rule": ""
|
| 49 |
+
}
|
| 50 |
+
],
|
| 51 |
+
"files": [
|
| 52 |
+
{
|
| 53 |
+
"scope": [],
|
| 54 |
+
"path": "",
|
| 55 |
+
"description": ""
|
| 56 |
+
}
|
| 57 |
+
]
|
| 58 |
+
}
|
| 59 |
+
}
|
physics_mcp/source/physicsnemo/__init__.py
ADDED
|
@@ -0,0 +1,22 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from .datapipes.datapipe import Datapipe
|
| 18 |
+
from .datapipes.meta import DatapipeMetaData
|
| 19 |
+
from .models.meta import ModelMetaData
|
| 20 |
+
from .models.module import Module
|
| 21 |
+
|
| 22 |
+
__version__ = "1.3.0a0"
|
physics_mcp/source/physicsnemo/active_learning/README.md
ADDED
|
@@ -0,0 +1,66 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Active Learning Module
|
| 2 |
+
|
| 3 |
+
The `physicsnemo.active_learning` namespace is used for defining the "scaffolding"
|
| 4 |
+
that can be used to construct automated, end-to-end active learning workflows.
|
| 5 |
+
For areas of science that are difficult to source ground-truths to train on
|
| 6 |
+
(of which there are many), an active learning curriculum attempts to train a
|
| 7 |
+
model with improved data efficiency: better generalization performance while requiring
|
| 8 |
+
fewer training samples.
|
| 9 |
+
|
| 10 |
+
Generally, an active learning workflow can be decomposed into three "phases"
|
| 11 |
+
that are - in the simplest case - run sequentially:
|
| 12 |
+
|
| 13 |
+
- **Training/fine-tuning**: A "learner" or surrogate model is initially trained
|
| 14 |
+
on available data, and in subsequent active learning iterations, is fine-tuned
|
| 15 |
+
with the new data appended on the original dataset.
|
| 16 |
+
- **Querying**: One or more strategies that encode some heuristics for what
|
| 17 |
+
new data is most informative for the learner. Examples of this include
|
| 18 |
+
uncertainty-based methods, which may screen a pool of unlabeled data for
|
| 19 |
+
those the model is least confident with.
|
| 20 |
+
- **Labeling**: A method of obtaining ground truth (labels) for new data
|
| 21 |
+
points, pipelined from the querying stage. This may entail running an
|
| 22 |
+
expensive solver, or acquiring experimental data.
|
| 23 |
+
|
| 24 |
+
The three phases are repeated until the learner converges. Because "convergence"
|
| 25 |
+
may not be easily defined, we define an additional phase which we call
|
| 26 |
+
**metrology**: this represents a phase most similar to querying, but allows
|
| 27 |
+
a user to define some set of criteria to monitor over the course of active
|
| 28 |
+
learning *beyond* simple validation metrics to ensure the model can be used
|
| 29 |
+
with confidence as a surrogate (e.g. within a simulation loop).
|
| 30 |
+
|
| 31 |
+
## How to use this module
|
| 32 |
+
|
| 33 |
+
With the context above in mind, inspecting the `driver` module will give you
|
| 34 |
+
a sense for how the end-to-end workflow functions; the `Driver` class acts
|
| 35 |
+
as an orchestrator for all the phases of active learning we described above.
|
| 36 |
+
|
| 37 |
+
From there, you should realize that `Driver` is written in a highly abstract
|
| 38 |
+
way: we need concrete *strategies* that implement querying, labeling, and metrology
|
| 39 |
+
concepts. The `protocols` module provides the scaffolding to do so - we implement
|
| 40 |
+
various components as `typing.Protocol` which are used for structural sub-typing:
|
| 41 |
+
they can be thought of as abstract classes that define an expected interface
|
| 42 |
+
in a function or class from which you can define your own classes by either
|
| 43 |
+
inheriting from them, or defining your own class that implements the expected
|
| 44 |
+
methods and attributes.
|
| 45 |
+
|
| 46 |
+
In order to perform the training portion of active learning, we provide a
|
| 47 |
+
minimal yet functional `DefaultTrainingLoop` inside the `loop` module. This
|
| 48 |
+
loop simply requires a `protocols.TrainingProtocol` to be passed, which is
|
| 49 |
+
a function that defines the logic for computing the loss per batch/training
|
| 50 |
+
step.
|
| 51 |
+
|
| 52 |
+
## Configuring workflows
|
| 53 |
+
|
| 54 |
+
The `config` module defines some simple `dataclass`es that can be used
|
| 55 |
+
to configure the behavior of various parts of active learning, e.g. how
|
| 56 |
+
training is conducted, etc. Because `Driver` is designed to be checkpointable,
|
| 57 |
+
with the exception of a few parts such as datasets, everything should be
|
| 58 |
+
JSON-serializable.
|
| 59 |
+
|
| 60 |
+
## Restarting workflows
|
| 61 |
+
|
| 62 |
+
For classes and functions that are created at runtime, checkpointing requires
|
| 63 |
+
that these components can be recreated when restarting from a checkpoint. To
|
| 64 |
+
that end, the `_registry` module provides a user-friendly way to instantiate
|
| 65 |
+
objects: user-defined strategy classes can be added to the registry to enable
|
| 66 |
+
their creation in checkpoint restarts.
|
physics_mcp/source/physicsnemo/active_learning/__init__.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from physicsnemo.active_learning._registry import registry
|
| 18 |
+
from physicsnemo.active_learning.config import (
|
| 19 |
+
DriverConfig,
|
| 20 |
+
OptimizerConfig,
|
| 21 |
+
StrategiesConfig,
|
| 22 |
+
TrainingConfig,
|
| 23 |
+
)
|
| 24 |
+
from physicsnemo.active_learning.driver import Driver
|
| 25 |
+
from physicsnemo.active_learning.loop import DefaultTrainingLoop
|
| 26 |
+
|
| 27 |
+
__all__ = [
|
| 28 |
+
"registry",
|
| 29 |
+
"Driver",
|
| 30 |
+
"DefaultTrainingLoop",
|
| 31 |
+
"DriverConfig",
|
| 32 |
+
"OptimizerConfig",
|
| 33 |
+
"StrategiesConfig",
|
| 34 |
+
"TrainingConfig",
|
| 35 |
+
]
|
physics_mcp/source/physicsnemo/active_learning/_registry.py
ADDED
|
@@ -0,0 +1,332 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import importlib
|
| 20 |
+
import inspect
|
| 21 |
+
from typing import Any, Callable
|
| 22 |
+
from warnings import warn
|
| 23 |
+
|
| 24 |
+
from physicsnemo.active_learning.protocols import ActiveLearningProtocol
|
| 25 |
+
|
| 26 |
+
__all__ = ["registry"]
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
class ActiveLearningRegistry:
    """
    Registry for active learning protocols.

    This class provides a centralized registry for user-defined active learning
    protocols that implement the `ActiveLearningProtocol`. It enables string-based
    lookups for checkpointing and provides argument validation when constructing
    protocol instances.

    The registry supports two primary modes of interaction:
    1. Registration via decorator: `@registry.register("my_strategy")`
    2. Construction with validation: `registry.construct("my_strategy", **kwargs)`

    Attributes
    ----------
    _registry : dict[str, type[ActiveLearningProtocol]]
        Internal dictionary mapping protocol names to their class types.

    Examples
    --------
    Register a custom strategy:

    >>> from physicsnemo.active_learning._registry import registry
    >>> @registry.register("my_custom_strategy")
    ... class MyCustomStrategy:
    ...     def __init__(self, param1: int, param2: str):
    ...         self.param1 = param1
    ...         self.param2 = param2

    Construct an instance with validation:

    >>> strategy = registry.construct("my_custom_strategy", param1=42, param2="test")
    """

    def __init__(self) -> None:
        """Initialize an empty registry."""
        self._registry: dict[str, type[ActiveLearningProtocol]] = {}

    def register(
        self, cls_name: str
    ) -> Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]:
        """
        Decorator to register an active learning protocol class.

        This decorator registers a class implementing the `ActiveLearningProtocol`
        under the given name, allowing it to be retrieved and constructed later
        using the `construct` method.

        Parameters
        ----------
        cls_name : str
            The name to register the protocol under. This will be used as the
            key for later retrieval.

        Returns
        -------
        Callable[[type[ActiveLearningProtocol]], type[ActiveLearningProtocol]]
            A decorator function that registers the class and returns it unchanged.

        Raises
        ------
        ValueError
            If a protocol with the same name is already registered.

        Examples
        --------
        >>> @registry.register("my_new_strategy")
        ... class MyStrategy:
        ...     def __init__(self, param: int):
        ...         self.param = param
        """

        def decorator(
            cls: type[ActiveLearningProtocol],
        ) -> type[ActiveLearningProtocol]:
            """
            Register ``cls`` under ``cls_name`` and return it unchanged.
            """
            if cls_name in self._registry:
                raise ValueError(
                    f"Protocol '{cls_name}' is already registered. "
                    f"Existing class: {self._registry[cls_name].__name__}"
                )
            self._registry[cls_name] = cls
            return cls

        return decorator

    def construct(
        self, cls_name: str, module_path: str | None = None, **kwargs: Any
    ) -> ActiveLearningProtocol:
        """
        Construct an instance of a registered protocol with argument validation.

        This method retrieves a registered protocol class by name (optionally
        importing it from ``module_path``), validates that the provided keyword
        arguments match the class's constructor signature, and returns a new
        instance of the class.

        Parameters
        ----------
        cls_name : str
            The name of the registered protocol to construct.
        module_path: str | None
            The path to the module to get the class from, used as a fallback
            when ``cls_name`` is not in the registry.
        **kwargs : Any
            Keyword arguments to pass to the protocol's constructor.

        Returns
        -------
        ActiveLearningProtocol
            A new instance of the requested protocol class.

        Raises
        ------
        NameError
            If the protocol name is neither registered nor importable from
            ``module_path`` (raised by ``get_class``).
        TypeError
            If the constructor signature cannot be inspected, or required
            constructor arguments are missing. Unexpected arguments only
            emit a warning (unless the constructor accepts ``**kwargs``,
            in which case they are passed through silently).

        Examples
        --------
        >>> from physicsnemo.active_learning._registry import registry
        >>> @registry.register("my_latest_strategy")
        ... class MyStrategy:
        ...     def __init__(self, param: int):
        ...         self.param = param
        >>> strategy = registry.construct("my_latest_strategy", param=42)
        """
        cls = self.get_class(cls_name, module_path)

        # Validate arguments against the class signature
        try:
            sig = inspect.signature(cls.__init__)
        except (ValueError, TypeError) as e:
            # Chain the original exception so the root cause stays visible.
            raise TypeError(
                f"Could not inspect signature of {cls.__name__}.__init__: {e}"
            ) from e

        # Get parameters, excluding 'self'
        params = {
            name: param for name, param in sig.parameters.items() if name != "self"
        }

        # Check if the signature accepts **kwargs
        has_var_keyword = any(
            p.kind == inspect.Parameter.VAR_KEYWORD for p in params.values()
        )

        # Check for missing required parameters (no default, not supplied)
        missing = [
            name
            for name, param in params.items()
            if param.kind
            not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
            and param.default is inspect.Parameter.empty
            and name not in kwargs
        ]

        if missing:
            raise TypeError(
                f"Missing required arguments for {cls.__name__}: {', '.join(missing)}"
            )

        # Check for unexpected parameters (unless **kwargs is present);
        # deliberately a warning rather than an error so that configs with
        # extra keys still construct.
        if not has_var_keyword:
            param_names = {
                name
                for name, param in params.items()
                if param.kind
                not in (inspect.Parameter.VAR_KEYWORD, inspect.Parameter.VAR_POSITIONAL)
            }
            unexpected = [name for name in kwargs if name not in param_names]

            if unexpected:
                warn(
                    f"Unexpected arguments for {cls.__name__}: {', '.join(unexpected)}. "
                    f"Valid parameters: {', '.join(sorted(param_names))}"
                )
        return cls(**kwargs)

    def __getitem__(self, cls_name: str) -> type[ActiveLearningProtocol]:
        """
        Retrieve a registered protocol class by name using dict-like access.

        This method allows accessing registered protocol classes using square
        bracket notation, e.g., `registry['my_strategy']`.

        Parameters
        ----------
        cls_name : str
            The name of the registered protocol to retrieve.

        Returns
        -------
        type[ActiveLearningProtocol]
            The class type of the registered protocol.

        Raises
        ------
        KeyError
            If the protocol name is not registered.

        Examples
        --------
        >>> from physicsnemo.active_learning._registry import registry
        >>> @registry.register("my_strategy")
        ... class MyStrategy:
        ...     def __init__(self, param: int):
        ...         self.param = param
        >>> RetrievedClass = registry['my_strategy']
        >>> instance = RetrievedClass(param=42)
        """
        if cls_name not in self._registry:
            available = ", ".join(self._registry.keys()) if self._registry else "none"
            raise KeyError(
                f"Protocol '{cls_name}' is not registered. "
                f"Available protocols: {available}"
            )
        return self._registry[cls_name]

    def is_registered(self, cls_name: str) -> bool:
        """
        Check if a protocol name is registered.

        Parameters
        ----------
        cls_name : str
            The name of the protocol to check.

        Returns
        -------
        bool
            True if the protocol is registered, False otherwise.
        """
        return cls_name in self._registry

    @property
    def registered_names(self) -> list[str]:
        """
        A list of all registered protocol names, sorted alphabetically.

        Returns
        -------
        list[str]
            A list of all registered protocol names, sorted alphabetically.
        """
        return sorted(self._registry.keys())

    def get_class(self, cls_name: str, module_path: str | None = None) -> type:
        """
        Get a class by name from the registry or from a module path.

        The registry takes precedence; ``module_path`` is only consulted
        when the name is not registered.

        Parameters
        ----------
        cls_name: str
            The name of the class to get.
        module_path: str | None
            The path to the module to get the class from.

        Returns
        -------
        type
            The class.

        Raises
        ------
        NameError: If the class is not found in the registry or module.
        ModuleNotFoundError: If the module is not found with the specified module path.
        """
        # Direct dict membership avoids building the sorted name list on
        # every lookup.
        if cls_name in self._registry:
            return self._registry[cls_name]
        if module_path:
            module = importlib.import_module(module_path)
            cls = getattr(module, cls_name, None)
            if cls is None:
                raise NameError(
                    f"Class {cls_name} not found in module {module_path}"
                )
            return cls
        raise NameError(
            f"Class {cls_name} not found in registry, and no module path was provided."
        )


# Module-level registry instance for global access
registry = ActiveLearningRegistry()
|
physics_mcp/source/physicsnemo/active_learning/config.py
ADDED
|
@@ -0,0 +1,808 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
Configuration dataclasses for the active learning driver.
|
| 19 |
+
|
| 20 |
+
This module provides structured configuration classes that separate different
|
| 21 |
+
concerns in the active learning workflow: optimization, training, strategies,
|
| 22 |
+
and driver orchestration.
|
| 23 |
+
"""
|
| 24 |
+
|
| 25 |
+
from __future__ import annotations
|
| 26 |
+
|
| 27 |
+
import math
|
| 28 |
+
import uuid
|
| 29 |
+
from collections import defaultdict
|
| 30 |
+
from dataclasses import dataclass, field
|
| 31 |
+
from json import dumps
|
| 32 |
+
from pathlib import Path
|
| 33 |
+
from typing import Any
|
| 34 |
+
from warnings import warn
|
| 35 |
+
|
| 36 |
+
import torch
|
| 37 |
+
from torch import distributed as dist
|
| 38 |
+
from torch.optim import AdamW, Optimizer
|
| 39 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 40 |
+
|
| 41 |
+
from physicsnemo.active_learning import protocols as p
|
| 42 |
+
from physicsnemo.active_learning._registry import registry
|
| 43 |
+
from physicsnemo.active_learning.loop import DefaultTrainingLoop
|
| 44 |
+
from physicsnemo.distributed import DistributedManager
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class OptimizerConfig:
    """
    Configuration for optimizer and learning rate scheduler.

    Bundles all training optimization settings together so they remain
    decoupled from the active learning orchestration logic.

    Attributes
    ----------
    optimizer_cls: type[Optimizer]
        The optimizer class to use. Defaults to AdamW.
    optimizer_kwargs: dict[str, Any]
        Keyword arguments forwarded to the optimizer constructor.
        Defaults to {"lr": 1e-4}.
    scheduler_cls: type[_LRScheduler] | None
        The learning rate scheduler class to use. When None, no
        scheduler is configured.
    scheduler_kwargs: dict[str, Any]
        Keyword arguments forwarded to the scheduler constructor.
    """

    optimizer_cls: type[Optimizer] = AdamW
    optimizer_kwargs: dict[str, Any] = field(default_factory=lambda: {"lr": 1e-4})
    scheduler_cls: type[_LRScheduler] | None = None
    scheduler_kwargs: dict[str, Any] = field(default_factory=dict)

    def __post_init__(self) -> None:
        """Check the optimization settings for obvious misconfiguration."""
        # A learning rate, when supplied, must be a positive number.
        if "lr" in self.optimizer_kwargs:
            lr = self.optimizer_kwargs["lr"]
            lr_is_valid = isinstance(lr, (int, float)) and lr > 0
            if not lr_is_valid:
                raise ValueError(f"Learning rate must be positive, got {lr}")

        # Scheduler kwargs are meaningless without a scheduler class.
        if self.scheduler_cls is None and self.scheduler_kwargs:
            raise ValueError(
                "scheduler_kwargs provided but scheduler_cls is None. "
                "Provide a scheduler_cls or remove scheduler_kwargs."
            )

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the OptimizerConfig.

        For round-tripping, the registry is used to de-serialize the optimizer
        and scheduler classes.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized.
        """

        def class_ref(klass: type) -> dict[str, str]:
            # Record just enough information to re-import the class later.
            return {"__name__": klass.__name__, "__module__": klass.__module__}

        scheduler_ref = class_ref(self.scheduler_cls) if self.scheduler_cls else None
        return {
            "optimizer_cls": class_ref(self.optimizer_cls),
            "optimizer_kwargs": self.optimizer_kwargs,
            "scheduler_cls": scheduler_ref,
            "scheduler_kwargs": self.scheduler_kwargs,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any]) -> OptimizerConfig:
        """
        Creates an OptimizerConfig instance from a dictionary.

        This method assumes that the optimizer and scheduler classes are
        included in the ``physicsnemo.active_learning.registry``, or a module
        path is recorded to import the class from.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.

        Returns
        -------
        OptimizerConfig
            A new ``OptimizerConfig`` instance.
        """
        opt_ref = data["optimizer_cls"]
        sched_ref = data.get("scheduler_cls")
        scheduler_cls = (
            registry.get_class(sched_ref["__name__"], sched_ref["__module__"])
            if sched_ref is not None
            else None
        )
        return cls(
            optimizer_cls=registry.get_class(
                opt_ref["__name__"], opt_ref["__module__"]
            ),
            optimizer_kwargs=data["optimizer_kwargs"],
            scheduler_cls=scheduler_cls,
            scheduler_kwargs=data["scheduler_kwargs"],
        )
|
| 151 |
+
|
| 152 |
+
|
| 153 |
+
@dataclass
class TrainingConfig:
    """
    Configuration for the training phase of active learning.

    This groups all training-related components together, making it
    clear when training is or isn't being used in the AL workflow.

    Attributes
    ----------
    train_datapool: p.DataPool
        The pool of labeled data to use for training.
    max_training_epochs: int
        The maximum number of epochs to train for. If ``max_fine_tuning_epochs``
        isn't specified, this value is used for all active learning steps.
    val_datapool: p.DataPool | None
        Optional pool of data to use for validation during training.
    optimizer_config: OptimizerConfig
        Configuration for the optimizer and scheduler. Defaults to
        AdamW with lr=1e-4, no scheduler.
    max_fine_tuning_epochs: int | None
        The maximum number of epochs used during fine-tuning steps, i.e. after
        the first active learning step. If ``None``, then the fine-tuning will
        be performed for the duration of the active learning loop.
    train_loop_fn: p.TrainingLoop
        The training loop function that orchestrates the training process.
        This defaults to a concrete implementation, ``DefaultTrainingLoop``,
        which provides a very typical loop that includes the use of static
        capture, etc.
    """

    train_datapool: p.DataPool
    max_training_epochs: int
    val_datapool: p.DataPool | None = None
    optimizer_config: OptimizerConfig = field(default_factory=OptimizerConfig)
    max_fine_tuning_epochs: int | None = None
    train_loop_fn: p.TrainingLoop = field(default_factory=DefaultTrainingLoop)

    def __post_init__(self) -> None:
        """Validate training configuration."""
        # Validate datapools have consistent interface.
        # This is a purely structural (duck-typed) check: only a usable
        # ``__len__`` is required here, not a specific DataPool subclass.
        if not hasattr(self.train_datapool, "__len__"):
            raise ValueError("train_datapool must implement __len__")
        if self.val_datapool is not None and not hasattr(self.val_datapool, "__len__"):
            raise ValueError("val_datapool must implement __len__")

        # Validate training loop is callable
        if not callable(self.train_loop_fn):
            raise ValueError("train_loop_fn must be callable")

        # set the same value for fine tuning epochs if not provided,
        # so downstream code can rely on it always being an int
        if self.max_fine_tuning_epochs is None:
            self.max_fine_tuning_epochs = self.max_training_epochs

    def to_dict(self) -> dict[str, Any]:
        """
        Returns a JSON-serializable dictionary representation of the TrainingConfig.

        For round-tripping, the registry is used to de-serialize the training loop
        and optimizer configuration. Note that datapools (train_datapool and val_datapool)
        are NOT serialized as they typically contain large datasets, file handles, or other
        non-serializable state.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized. Excludes datapools.

        Warnings
        --------
        This method will issue a warning about the exclusion of datapools.
        """
        # Warn about datapool exclusion.
        # NOTE(review): the warning is emitted unconditionally, even when no
        # val_datapool was configured — confirm this is intended.
        warn(
            "The `train_datapool` and `val_datapool` attributes are not supported for "
            "serialization and will be excluded from the ``TrainingConfig`` dictionary. "
            "You must re-provide these datapools when deserializing."
        )

        # Serialize optimizer config
        optimizer_dict = self.optimizer_config.to_dict()

        # Serialize training loop function.
        # `_args` is presumably the constructor-argument capture produced by
        # `ActiveLearningProtocol.__new__` (a {"__name__", "__module__",
        # "__args__"} mapping) — verify against the protocols module.
        if not hasattr(self.train_loop_fn, "_args"):
            raise ValueError(
                f"Training loop {self.train_loop_fn} does not have an `_args` attribute "
                "which is required for serialization. Make sure your training loop "
                "either subclasses `ActiveLearningProtocol` or implements the `__new__` "
                "method to capture object arguments."
            )

        train_loop_dict = self.train_loop_fn._args

        return {
            "max_training_epochs": self.max_training_epochs,
            "max_fine_tuning_epochs": self.max_fine_tuning_epochs,
            "optimizer_config": optimizer_dict,
            "train_loop_fn": train_loop_dict,
        }

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> TrainingConfig:
        """
        Creates a TrainingConfig instance from a dictionary.

        This method assumes that the training loop class is included in the
        ``physicsnemo.active_learning.registry``, or a module path is specified
        to import the class from. Note that datapools must be provided via
        kwargs as they are not serialized.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor. This is where
            you must provide ``train_datapool`` and optionally ``val_datapool``.

        Returns
        -------
        TrainingConfig
            A new ``TrainingConfig`` instance.

        Raises
        ------
        ValueError
            If required datapools are not provided in kwargs, if the data contains
            unexpected keys, or if object construction fails.
        """
        # Ensure required datapools are provided
        if "train_datapool" not in kwargs:
            raise ValueError(
                "``train_datapool`` must be provided in kwargs when deserializing "
                "TrainingConfig, as datapools are not serialized."
            )

        # Reconstruct optimizer config
        optimizer_config = OptimizerConfig.from_dict(data["optimizer_config"])

        # Reconstruct training loop function from its serialized spec
        # (resolved via the registry, falling back to a module import).
        train_loop_data = data["train_loop_fn"]
        train_loop_fn = registry.construct(
            train_loop_data["__name__"],
            module_path=train_loop_data["__module__"],
            **train_loop_data["__args__"],
        )

        # Build the config.
        # NOTE(review): the docstring mentions raising on unexpected keys, but
        # unlike StrategiesConfig.from_dict no explicit key validation is
        # performed here — unexpected data keys are silently ignored.
        try:
            config = cls(
                max_training_epochs=data["max_training_epochs"],
                max_fine_tuning_epochs=data.get("max_fine_tuning_epochs"),
                optimizer_config=optimizer_config,
                train_loop_fn=train_loop_fn,
                **kwargs,
            )
        except Exception as e:
            raise ValueError(
                "Failed to construct ``TrainingConfig`` from dictionary."
            ) from e

        return config
|
| 315 |
+
|
| 316 |
+
|
| 317 |
+
@dataclass
class StrategiesConfig:
    """
    Configuration for active learning strategies and data acquisition.

    This encapsulates the query-label-metrology cycle that is at the
    heart of active learning: strategies for selecting data, labeling it,
    and measuring model uncertainty/performance.

    Attributes
    ----------
    query_strategies: list[p.QueryStrategy]
        The query strategies to use for selecting data to label.
    queue_cls: type[p.AbstractQueue]
        The queue implementation to use for passing data between
        query and labeling phases.
    label_strategy: p.LabelStrategy | None
        The strategy to use for labeling queried data. If None,
        labeling will be skipped.
    metrology_strategies: list[p.MetrologyStrategy] | None
        Strategies for measuring model performance and uncertainty.
        If None, metrology will be skipped.
    unlabeled_datapool: p.DataPool | None
        Pool of unlabeled data that query strategies can sample from.
        Not all strategies require this (some may generate synthetic data).
    """

    query_strategies: list[p.QueryStrategy]
    queue_cls: type[p.AbstractQueue]
    label_strategy: p.LabelStrategy | None = None
    metrology_strategies: list[p.MetrologyStrategy] | None = None
    unlabeled_datapool: p.DataPool | None = None

    def __post_init__(self) -> None:
        """Validate strategies configuration."""
        # Must have at least one query strategy
        if not self.query_strategies:
            raise ValueError(
                "At least one query strategy must be provided. "
                "Active learning requires a mechanism to select data."
            )

        # All query strategies must be callable
        for strategy in self.query_strategies:
            if not callable(strategy):
                raise ValueError(f"Query strategy {strategy} must be callable")

        # Label strategy must be callable if provided
        if self.label_strategy is not None and not callable(self.label_strategy):
            raise ValueError("label_strategy must be callable")

        # Metrology strategies must be callable if provided; an explicit empty
        # list is rejected to force the caller to state intent (None = skip)
        if self.metrology_strategies is not None:
            if not self.metrology_strategies:
                raise ValueError(
                    "metrology_strategies is an empty list. "
                    "Either provide strategies or set to None to skip metrology."
                )
            for strategy in self.metrology_strategies:
                if not callable(strategy):
                    raise ValueError(f"Metrology strategy {strategy} must be callable")

        # Validate queue class has basic queue interface.
        # NOTE(review): every class object has ``__call__`` (instantiation),
        # so this check accepts any class — it only rejects non-callables.
        if not hasattr(self.queue_cls, "__call__"):
            raise ValueError("queue_cls must be a callable class")

    def to_dict(self) -> dict[str, Any]:
        """
        Method that converts the present ``StrategiesConfig`` instance into a dictionary
        that can be JSON serialized.

        This method, for the most part, assumes that strategies are subclasses of
        ``ActiveLearningProtocol`` and/or they have an ``_args`` attribute that
        captures the arguments to the constructor.

        One issue is the inability to reliably serialize the ``unlabeled_datapool``,
        which for the most part, likely does not need serialization as a dataset.
        Regardless, this method will trigger a warning if ``unlabeled_datapool`` is
        not None.

        Returns
        -------
        dict[str, Any]
            A dictionary that can be JSON serialized. Note this is actually a
            ``defaultdict(list)``; keys for unset optional strategies
            (``label_strategy``, ``metrology_strategies``) are simply absent,
            which ``from_dict`` relies on via membership tests.
        """
        output = defaultdict(list)
        for strategy in self.query_strategies:
            # `_args` is presumably the {"__name__", "__module__", "__args__"}
            # capture made by ActiveLearningProtocol.__new__ — confirm against
            # the protocols module.
            if not hasattr(strategy, "_args"):
                raise ValueError(
                    f"Query strategy {strategy} does not have an `_args` attribute"
                    " which is required for serialization. Make sure your strategy"
                    " either subclasses `ActiveLearningProtocol` or implements"
                    " the `__new__` method to capture object arguments."
                )
            output["query_strategies"].append(strategy._args)
        if self.label_strategy is not None:
            if not hasattr(self.label_strategy, "_args"):
                raise ValueError(
                    f"Label strategy {self.label_strategy} does not have an `_args` attribute"
                    " which is required for serialization. Make sure your strategy"
                    " either subclasses `ActiveLearningProtocol` or implements"
                    " the `__new__` method to capture object arguments."
                )
            output["label_strategy"] = self.label_strategy._args
        output["queue_cls"] = {
            "__name__": self.queue_cls.__name__,
            "__module__": self.queue_cls.__module__,
        }
        if self.metrology_strategies is not None:
            for strategy in self.metrology_strategies:
                if not hasattr(strategy, "_args"):
                    raise ValueError(
                        f"Metrology strategy {strategy} does not have an `_args` attribute"
                        " which is required for serialization. Make sure your strategy"
                        " either subclasses `ActiveLearningProtocol` or implements"
                        " the `__new__` method to capture object arguments."
                    )
                output["metrology_strategies"].append(strategy._args)
        if self.unlabeled_datapool is not None:
            # Datapool is intentionally dropped rather than serialized; the
            # caller must re-provide it on deserialization.
            warn(
                "The `unlabeled_datapool` attribute is not supported for serialization"
                " and will be excluded from the ``StrategiesConfig`` dictionary."
            )
        return output

    @classmethod
    def from_dict(cls, data: dict[str, Any], **kwargs: Any) -> StrategiesConfig:
        """
        Create a ``StrategiesConfig`` instance from a dictionary.

        This method heavily relies on classes being added to the
        ``physicsnemo.active_learning.registry``, which is used to instantiate
        all strategies and custom types used in active learning. As a fall
        back, the `registry.construct` method will try and import the class
        from the module path if it is not found in the registry.

        Parameters
        ----------
        data: dict[str, Any]
            A dictionary that was previously serialized using the ``to_dict`` method.
        **kwargs: Any
            Additional keyword arguments to pass to the constructor.

        Returns
        -------
        StrategiesConfig
            A new ``StrategiesConfig`` instance.

        Raises
        ------
        ValueError:
            If the data contains unexpected keys or if the object construction fails.
        NameError:
            If a class is not found in the registry and no module path is provided.
        ModuleNotFoundError:
            If a module is not found with the specified module path.
        """
        # ensure that the data contains no unexpected keys
        data_keys = set(data.keys())
        expected_keys = set(cls.__dataclass_fields__.keys())
        extra_keys = data_keys - expected_keys
        if extra_keys:
            raise ValueError(
                f"Unexpected keys in data: {extra_keys}. Expected keys are {expected_keys}."
            )
        # instantiate objects from the serialized data; general strategy is to
        # use `registry.construct` that will try and resolve the class within
        # the registry first, and if not found, then it will try and import the
        # class from the module path.
        output_dict = defaultdict(list)
        for entry in data["query_strategies"]:
            output_dict["query_strategies"].append(
                registry.construct(
                    entry["__name__"],
                    module_path=entry["__module__"],
                    **entry["__args__"],
                )
            )
        if "metrology_strategies" in data:
            for entry in data["metrology_strategies"]:
                output_dict["metrology_strategies"].append(
                    registry.construct(
                        entry["__name__"],
                        module_path=entry["__module__"],
                        **entry["__args__"],
                    )
                )
        if "label_strategy" in data:
            output_dict["label_strategy"] = registry.construct(
                data["label_strategy"]["__name__"],
                module_path=data["label_strategy"]["__module__"],
                **data["label_strategy"]["__args__"],
            )
        output_dict["queue_cls"] = registry.get_class(
            data["queue_cls"]["__name__"], data["queue_cls"]["__module__"]
        )
        # potentially override with keyword arguments
        # (also the only way to supply ``unlabeled_datapool``, which is never
        # serialized)
        output_dict.update(kwargs)
        try:
            config = cls(**output_dict)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``StrategiesConfig`` from dictionary."
            ) from e
        return config
|
| 522 |
+
|
| 523 |
+
|
| 524 |
+
@dataclass
class DriverConfig:
    """
    Configuration for driver orchestration and infrastructure.

    This contains parameters that control the overall loop execution,
    logging, checkpointing, and distributed training setup - orthogonal
    to the specific AL or training logic.

    Attributes
    ----------
    batch_size: int
        The batch size to use for data loaders.
    max_active_learning_steps: int | None, default None
        Maximum number of AL iterations to perform. None means infinite.
    run_id: str, default auto-generated UUID
        Unique identifier for this run. Auto-generated if not provided.
    fine_tuning_lr: float | None, default None
        Learning rate to switch to after the first AL step for fine-tuning.
    reset_optim_states: bool, default True
        Whether to reset optimizer states between AL steps.
    skip_training: bool, default False
        If True, skip the training phase entirely.
    skip_metrology: bool, default False
        If True, skip the metrology phase entirely.
    skip_labeling: bool, default False
        If True, skip the labeling phase entirely.
    checkpoint_interval: int, default 1
        Save model checkpoint every N AL steps. 0 disables checkpointing.
    checkpoint_on_training: bool, default False
        If True, save checkpoint at the start of the training phase.
    checkpoint_on_metrology: bool, default False
        If True, save checkpoint at the start of the metrology phase.
    checkpoint_on_query: bool, default False
        If True, save checkpoint at the start of the query phase.
    checkpoint_on_labeling: bool, default True
        If True, save checkpoint at the start of the labeling phase.
    model_checkpoint_frequency: int, default 0
        Save model weights every N epochs during training. 0 means only save
        between active learning phases. Useful for mid-training restarts.
    root_log_dir: str | Path, default Path.cwd() / "active_learning_logs"
        Directory to save logs and checkpoints to. Defaults to
        an 'active_learning_logs' directory in the current working directory.
    dist_manager: DistributedManager | None, default None
        Manager for distributed training configuration.
    collate_fn: callable | None, default None
        Custom collate function for batching data.
    num_dataloader_workers: int, default 0
        Number of worker processes for data loading.
    device: str | torch.device | None, default None
        Device to use for model and data. This is intended for single process
        workflows; for distributed workflows, the device should be set in
        ``DistributedManager`` instead. If not specified, then the device
        will default to ``torch.get_default_device()``.
    dtype: torch.dtype | None, default None
        The dtype to use for model and data, and AMP contexts. If not provided,
        then the dtype will default to ``torch.get_default_dtype()``.
    """

    batch_size: int
    max_active_learning_steps: int | None = None
    run_id: str = field(default_factory=lambda: str(uuid.uuid4()))
    fine_tuning_lr: float | None = None  # TODO: move to TrainingConfig
    reset_optim_states: bool = True
    skip_training: bool = False
    skip_metrology: bool = False
    skip_labeling: bool = False
    checkpoint_interval: int = 1
    checkpoint_on_training: bool = False
    checkpoint_on_metrology: bool = False
    checkpoint_on_query: bool = False
    checkpoint_on_labeling: bool = True
    model_checkpoint_frequency: int = 0
    # NOTE: this default is evaluated once, at class-definition (import) time,
    # so the cwd captured here is the importing process's cwd.
    root_log_dir: str | Path = field(default=Path.cwd() / "active_learning_logs")
    dist_manager: DistributedManager | None = None
    collate_fn: callable | None = None
    num_dataloader_workers: int = 0
    device: str | torch.device | None = None
    dtype: torch.dtype | None = None

    def __post_init__(self) -> None:
        """Validate driver configuration and normalize defaults."""
        # None means "run forever"; normalize to +inf so comparisons below and
        # in the driver loop work uniformly.
        if self.max_active_learning_steps is None:
            self.max_active_learning_steps = float("inf")

        # After the normalization above, the value is never None, so only the
        # positivity check is needed (inf passes it).
        if self.max_active_learning_steps <= 0:
            raise ValueError(
                "`max_active_learning_steps` must be a positive integer or None."
            )

        # isfinite rejects inf/nan sneaking in as a float batch size
        if not math.isfinite(self.batch_size) or self.batch_size <= 0:
            raise ValueError("`batch_size` must be a positive integer.")

        if not math.isfinite(self.checkpoint_interval) or self.checkpoint_interval < 0:
            raise ValueError(
                "`checkpoint_interval` must be a non-negative integer. "
                "Use 0 to disable checkpointing."
            )

        if self.fine_tuning_lr is not None and self.fine_tuning_lr <= 0:
            raise ValueError("`fine_tuning_lr` must be positive if provided.")

        if self.num_dataloader_workers < 0:
            raise ValueError("`num_dataloader_workers` must be non-negative.")

        if self.model_checkpoint_frequency < 0:
            raise ValueError("`model_checkpoint_frequency` must be non-negative.")

        # Normalize string paths to Path for uniform downstream handling
        if isinstance(self.root_log_dir, str):
            self.root_log_dir = Path(self.root_log_dir)

        # Validate collate_fn if provided
        if self.collate_fn is not None and not callable(self.collate_fn):
            raise ValueError("`collate_fn` must be callable if provided.")

        # device and dtype setup when not using DistributedManager; with a
        # dist_manager, device selection is deferred to it.
        if self.device is None and not self.dist_manager:
            self.device = torch.get_default_device()
        if self.dtype is None:
            self.dtype = torch.get_default_dtype()

    def to_json(self) -> str:
        """
        Returns a JSON string representation of the ``DriverConfig``.

        Note that certain fields are not serialized and must be provided when
        deserializing: ``dist_manager``, ``collate_fn``.

        Returns
        -------
        str
            A JSON string representation of the config.
        """
        # base dict representation skips Python objects that are either
        # non-serializable or handled specially below
        dict_repr = {
            key: self.__dict__[key]
            for key in self.__dict__
            if key
            not in ["dist_manager", "collate_fn", "root_log_dir", "device", "dtype"]
        }
        # Note: checkpoint flags are included in dict_repr automatically
        dict_repr["default_dtype"] = str(torch.get_default_dtype())
        dict_repr["log_dir"] = str(self.root_log_dir)
        # Convert dtype to string for JSON serialization
        if self.dtype is not None:
            dict_repr["dtype"] = str(self.dtype)
        else:
            dict_repr["dtype"] = None
        if self.dist_manager is not None:
            dict_repr["world_size"] = self.dist_manager.world_size
            dict_repr["device"] = self.dist_manager.device.type
            dict_repr["dist_manager_init_method"] = (
                self.dist_manager._initialization_method
            )
        else:
            if dist.is_initialized():
                world_size = dist.get_world_size()
            else:
                world_size = 1
            dict_repr["world_size"] = world_size
            if self.device is not None:
                # Fix: the original conditional had two identical branches
                # (`str(self.device)` either way); collapsed to the single
                # equivalent expression. str() handles both torch.device and
                # plain string devices.
                dict_repr["device"] = str(self.device)
            else:
                dict_repr["device"] = torch.get_default_device().type
            dict_repr["dist_manager_init_method"] = None
        if self.collate_fn is not None:
            dict_repr["collate_fn"] = self.collate_fn.__name__
        else:
            dict_repr["collate_fn"] = None
        return dumps(dict_repr, indent=2)

    @classmethod
    def from_json(cls, json_str: str, **kwargs: Any) -> DriverConfig:
        """
        Creates a DriverConfig instance from a JSON string.

        This method reconstructs a DriverConfig from JSON. Note that certain
        fields cannot be serialized and must be provided via kwargs:
        - ``dist_manager``: DistributedManager instance (optional)
        - ``collate_fn``: Custom collate function (optional)

        Parameters
        ----------
        json_str: str
            A JSON string that was previously serialized using ``to_json()``.
        **kwargs: Any
            Additional keyword arguments to override or provide non-serializable
            fields like ``dist_manager`` and ``collate_fn``.

        Returns
        -------
        DriverConfig
            A new ``DriverConfig`` instance.

        Raises
        ------
        ValueError
            If the JSON cannot be parsed or required fields are missing.

        Notes
        -----
        The device and dtype fields are reconstructed from their string
        representations. The ``log_dir`` field in JSON is mapped to
        ``root_log_dir`` in the config.
        """
        import json

        try:
            data = json.loads(json_str)
        except json.JSONDecodeError as e:
            raise ValueError(f"Invalid JSON string: {e}") from e

        # Define fields that are not actual DriverConfig constructor parameters
        metadata_fields = [
            "default_dtype",
            "world_size",
            "dist_manager_init_method",
            "log_dir",  # handled separately as root_log_dir
        ]
        non_serializable_fields = [
            "dist_manager",
            "collate_fn",
            "root_log_dir",
            "device",
            "dtype",
        ]

        # Extract serializable fields that map directly
        config_fields = {
            key: value
            for key, value in data.items()
            if key not in metadata_fields + non_serializable_fields
        }

        # Handle root_log_dir (stored as "log_dir" in JSON)
        if "log_dir" in data:
            config_fields["root_log_dir"] = Path(data["log_dir"])

        # Handle device reconstruction from string
        if "device" in data and data["device"] is not None:
            device_str = data["device"]
            # Parse device strings like "cuda:0", "cpu", "cuda", etc.
            config_fields["device"] = torch.device(device_str)

        # Handle dtype reconstruction from string
        if "dtype" in data and data["dtype"] is not None:
            dtype_str = data["dtype"]
            # Map string representations to torch dtypes
            dtype_map = {
                "torch.float32": torch.float32,
                "torch.float64": torch.float64,
                "torch.float16": torch.float16,
                "torch.bfloat16": torch.bfloat16,
                "torch.int32": torch.int32,
                "torch.int64": torch.int64,
                "torch.int8": torch.int8,
                "torch.uint8": torch.uint8,
            }
            if dtype_str in dtype_map:
                config_fields["dtype"] = dtype_map[dtype_str]
            else:
                # Unknown dtype: leave it unset so __post_init__ falls back to
                # torch.get_default_dtype()
                warn(
                    f"Unknown dtype string '{dtype_str}' in JSON. "
                    "Using default dtype instead."
                )

        # Merge with provided kwargs (allows overriding and adding non-serializable fields)
        config_fields.update(kwargs)

        # Create the config
        try:
            config = cls(**config_fields)
        except Exception as e:
            raise ValueError(
                "Failed to construct ``DriverConfig`` from JSON string."
            ) from e

        return config
|
physics_mcp/source/physicsnemo/active_learning/driver.py
ADDED
|
@@ -0,0 +1,1449 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
This module contains the definition for an active learning driver
|
| 19 |
+
class, which is responsible for orchestration and automation of
|
| 20 |
+
the end-to-end active learning process.
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
from __future__ import annotations
|
| 24 |
+
|
| 25 |
+
import inspect
|
| 26 |
+
import pickle
|
| 27 |
+
from contextlib import contextmanager
|
| 28 |
+
from copy import deepcopy
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Any, Generator
|
| 32 |
+
|
| 33 |
+
import torch
|
| 34 |
+
from torch import distributed as dist
|
| 35 |
+
from torch.nn.parallel import DistributedDataParallel
|
| 36 |
+
from torch.utils.data import DataLoader, DistributedSampler
|
| 37 |
+
|
| 38 |
+
from physicsnemo import Module
|
| 39 |
+
from physicsnemo import __version__ as physicsnemo_version
|
| 40 |
+
from physicsnemo.active_learning import protocols as p
|
| 41 |
+
from physicsnemo.active_learning.config import (
|
| 42 |
+
DriverConfig,
|
| 43 |
+
StrategiesConfig,
|
| 44 |
+
TrainingConfig,
|
| 45 |
+
)
|
| 46 |
+
from physicsnemo.active_learning.logger import (
|
| 47 |
+
ActiveLearningLoggerAdapter,
|
| 48 |
+
setup_active_learning_logger,
|
| 49 |
+
)
|
| 50 |
+
from physicsnemo.distributed import DistributedManager
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
@dataclass
class ActiveLearningCheckpoint:
    """
    Metadata associated with an ongoing (or completed) active
    learning experiment.

    The information contained in this metadata should be sufficient
    to restart the active learning experiment at the nearest point:
    for example, training should be able to continue from an epoch,
    while for querying/sampling, etc. we continue from a pre-existing
    queue.
    """

    # Orchestration/infrastructure configuration for the driver.
    # NOTE(review): ``save_checkpoint`` passes the JSON-serialized form
    # (``config.to_json()``) here rather than a ``DriverConfig`` instance —
    # confirm the intended runtime type matches this annotation.
    driver_config: DriverConfig
    # Strategy (query/label/metrology) configuration; ``save_checkpoint``
    # stores the dict form produced by ``to_dict()``.
    strategies_config: StrategiesConfig
    # Active learning iteration index at which the checkpoint was taken.
    active_learning_step_idx: int
    # Phase (training/metrology/query/labeling) in progress at save time.
    active_learning_phase: p.ActiveLearningPhase
    # Version of physicsnemo that produced this checkpoint, for
    # compatibility checks on restore.
    physicsnemo_version: str = physicsnemo_version
    # Training components configuration (dict form); None when training
    # is skipped.
    training_config: TrainingConfig | None = None
    # Optimizer / LR scheduler states. Currently always saved as None —
    # the training loop persists these separately (see ``save_checkpoint``).
    optimizer_state: dict[str, Any] | None = None
    lr_scheduler_state: dict[str, Any] | None = None
    # Whether the query/label queue contents were successfully serialized
    # to sibling files next to this checkpoint.
    has_query_queue: bool = False
    has_label_queue: bool = False
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
class Driver(p.DriverProtocol):
    """
    Provides a simple implementation of the ``DriverProtocol`` used to
    orchestrate an active learning process within PhysicsNeMo.

    At a high level, the active learning process is broken down into four
    phases: training, metrology, query, and labeling.

    To understand the orchestration, start by inspecting the
    ``active_learning_step`` method, which defines a single iteration of
    the active learning loop, which is dispatched by the ``run`` method.
    From there, it should be relatively straightforward to trace the
    remaining components.

    Attributes
    ----------
    config: DriverConfig
        Infrastructure and orchestration configuration.
    learner: Module | p.LearnerProtocol
        The learner module for the active learning process.
    strategies_config: StrategiesConfig
        Active learning strategies (query, label, metrology).
    training_config: TrainingConfig | None
        Training components. None if training is skipped.
    inference_fn: p.InferenceProtocol | None
        Custom inference function.
    active_learning_step_idx: int
        Current iteration index of the active learning loop.
    query_queue: p.AbstractQueue
        Queue populated with data by query strategies.
    label_queue: p.AbstractQueue
        Queue populated with labeled data by the label strategy.
    optimizer: torch.optim.Optimizer | None
        Configured optimizer (set after configure_optimizer is called).
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None
        Configured learning rate scheduler.
    logger: logging.Logger
        Persistent logger for the active learning process.
    """

    # Phase execution order for a single active learning step. Treated as
    # immutable class-level data — do not mutate at runtime.
    _PHASE_ORDER = [
        p.ActiveLearningPhase.TRAINING,
        p.ActiveLearningPhase.METROLOGY,
        p.ActiveLearningPhase.QUERY,
        p.ActiveLearningPhase.LABELING,
    ]
|
| 125 |
+
|
| 126 |
+
def __init__(
    self,
    config: DriverConfig,
    learner: Module | p.LearnerProtocol,
    strategies_config: StrategiesConfig,
    training_config: TrainingConfig | None = None,
    inference_fn: p.InferenceProtocol | None = None,
) -> None:
    """
    Initializes the active learning driver.

    At the bare minimum, the driver requires a config, learner, and
    strategies config to be used in a purely querying loop. Additional
    arguments can be provided to enable training and other workflows.

    Parameters
    ----------
    config: DriverConfig
        Orchestration and infrastructure configuration, for example
        the batch size, the log directory, the distributed manager, etc.
    learner: Module | p.LearnerProtocol
        The model to use for active learning.
    strategies_config: StrategiesConfig
        Container for active learning strategies (query, label, metrology).
    training_config: TrainingConfig | None
        Training components. Required if ``skip_training`` is False in
        the ``DriverConfig``.
    inference_fn: p.InferenceProtocol | None
        Custom inference function. If None, uses ``learner.__call__``.
        This is not actually called by the driver, but is stored as an
        attribute for attached strategies to use as needed.

    Raises
    ------
    ValueError
        If the composed configs violate a cross-config constraint
        (see ``_validate_config_consistency``).
    """
    # Configs have already validated themselves in __post_init__;
    # only cross-config constraints are checked below.
    self.config = config
    self.learner = learner
    self.strategies_config = strategies_config
    self.training_config = training_config
    self.inference_fn = inference_fn
    # Goes through the validating property setter (rejects negatives).
    self.active_learning_step_idx = 0
    self.current_phase: p.ActiveLearningPhase | None = (
        None  # Track current phase for logging context
    )
    self._last_checkpoint_path: Path | None = None

    # Validate cross-config constraints; must run after the attributes
    # above are populated since it reads them.
    self._validate_config_consistency()

    # Logger is set up before strategies are attached.
    # NOTE(review): attach() is implemented in the protocol base class
    # (not visible here) — confirm whether it relies on ``self.logger``.
    self._setup_logger()
    self.attach_strategies()

    # Initialize empty work queues from the configured queue class:
    # query strategies fill ``query_queue``; the label strategy fills
    # ``label_queue``.
    self.query_queue = strategies_config.queue_cls()
    self.label_queue = strategies_config.queue_cls()
|
| 179 |
+
|
| 180 |
+
def _validate_config_consistency(self) -> None:
|
| 181 |
+
"""
|
| 182 |
+
Validate consistency across configs.
|
| 183 |
+
|
| 184 |
+
Each config validates itself, but this method checks relationships
|
| 185 |
+
between configs that can only be validated when composed together.
|
| 186 |
+
"""
|
| 187 |
+
# If training is not skipped, training_config must be provided
|
| 188 |
+
if not self.config.skip_training and self.training_config is None:
|
| 189 |
+
raise ValueError(
|
| 190 |
+
"`training_config` must be provided when `skip_training` is False."
|
| 191 |
+
)
|
| 192 |
+
|
| 193 |
+
# If labeling is not skipped, must have label strategy and train datapool
|
| 194 |
+
if not self.config.skip_labeling:
|
| 195 |
+
if self.strategies_config.label_strategy is None:
|
| 196 |
+
raise ValueError(
|
| 197 |
+
"`label_strategy` must be provided in strategies_config "
|
| 198 |
+
"when `skip_labeling` is False."
|
| 199 |
+
)
|
| 200 |
+
if (
|
| 201 |
+
self.training_config is None
|
| 202 |
+
or self.training_config.train_datapool is None
|
| 203 |
+
):
|
| 204 |
+
raise ValueError(
|
| 205 |
+
"`train_datapool` must be provided in training_config "
|
| 206 |
+
"when `skip_labeling` is False (labeled data is appended to it)."
|
| 207 |
+
)
|
| 208 |
+
|
| 209 |
+
# If fine-tuning lr is set, must have training enabled
|
| 210 |
+
if self.config.fine_tuning_lr is not None and self.config.skip_training:
|
| 211 |
+
raise ValueError(
|
| 212 |
+
"`fine_tuning_lr` has no effect when `skip_training` is True."
|
| 213 |
+
)
|
| 214 |
+
|
| 215 |
+
@property
def query_strategies(self) -> list[p.QueryStrategy]:
    """Query strategies configured for this experiment."""
    return self.strategies_config.query_strategies

@property
def label_strategy(self) -> p.LabelStrategy | None:
    """Label strategy configured for this experiment, if any."""
    return self.strategies_config.label_strategy

@property
def metrology_strategies(self) -> list[p.MetrologyStrategy] | None:
    """Metrology strategies configured for this experiment, if any."""
    return self.strategies_config.metrology_strategies

@property
def unlabeled_datapool(self) -> p.DataPool | None:
    """Pool of unlabeled samples available for querying, if any."""
    return self.strategies_config.unlabeled_datapool

@property
def train_datapool(self) -> p.DataPool | None:
    """Training datapool, or ``None`` when training is not configured."""
    if not self.training_config:
        return None
    return self.training_config.train_datapool

@property
def val_datapool(self) -> p.DataPool | None:
    """Validation datapool, or ``None`` when training is not configured."""
    if not self.training_config:
        return None
    return self.training_config.val_datapool

@property
def train_loop_fn(self) -> p.TrainingLoop | None:
    """Training loop callable, or ``None`` when training is not configured."""
    if not self.training_config:
        return None
    return self.training_config.train_loop_fn
|
| 249 |
+
|
| 250 |
+
@property
def device(self) -> torch.device:
    """Device shared by all driver components.

    Resolves to the distributed manager's device when one is active,
    otherwise falls back to the framework default device.
    """
    manager = self.dist_manager
    if manager is not None and manager.is_initialized():
        return manager.device
    return torch.get_default_device()

@property
def run_id(self) -> str:
    """Run identifier taken from the ``DriverConfig``.

    Returns
    -------
    str
        The run id.
    """
    return self.config.run_id

@property
def log_dir(self) -> Path:
    """Per-run log directory.

    This is ``DriverConfig.root_log_dir`` joined with the shortened run
    ID, so each run gets its own directory for logs, checkpoints, etc.

    Returns
    -------
    Path
        The log directory.
    """
    root = self.config.root_log_dir
    return root / self.short_run_id

@property
def short_run_id(self) -> str:
    """First 8 characters of the run id.

    The 8 character limit assumes the run ID is a UUID4; the short form
    is convenient for user-facing interfaces.

    Returns
    -------
    str
        The first 8 characters of the run id.
    """
    return self.run_id[:8]

@property
def last_checkpoint(self) -> Path | None:
    """Path to the most recently saved checkpoint directory.

    Returns
    -------
    Path | None
        The last checkpoint directory, or ``None`` if no checkpoint has
        been saved yet.
    """
    return self._last_checkpoint_path
|
| 313 |
+
|
| 314 |
+
@property
def active_learning_step_idx(self) -> int:
    """
    Current active learning step index.

    Counts how many times the active learning step has been invoked,
    i.e. the number of loop iterations.

    Returns
    -------
    int
        The current active learning step index.
    """
    return self._active_learning_step_idx

@active_learning_step_idx.setter
def active_learning_step_idx(self, value: int) -> None:
    """
    Set the current active learning step index.

    Parameters
    ----------
    value: int
        The new active learning step index.

    Raises
    ------
    ValueError
        If the new index is negative.
    """
    if value < 0:
        raise ValueError("Active learning step index must be non-negative.")
    self._active_learning_step_idx = value
|
| 347 |
+
|
| 348 |
+
@property
def dist_manager(self) -> DistributedManager | None:
    """Distributed manager supplied via the ``DriverConfig``, if any.

    Returns
    -------
    DistributedManager | None
        The distributed manager, or ``None`` when running without one.
    """
    return self.config.dist_manager
|
| 359 |
+
|
| 360 |
+
def configure_optimizer(self) -> None:
    """
    Instantiate the optimizer and LR scheduler declared in ``training_config``.

    Both ``self.optimizer`` and ``self.lr_scheduler`` are always assigned
    (possibly to ``None``) so downstream code can rely on the attributes
    existing. Keyword arguments are validated against the target class
    signatures before instantiation so configuration mistakes surface as
    a clear ``ValueError`` rather than an opaque constructor error.

    Raises
    ------
    ValueError
        If optimizer or scheduler keyword arguments do not match the
        respective class signature.
    """
    # Nothing to configure without a training config.
    if self.training_config is None:
        self.optimizer = None
        self.lr_scheduler = None
        return

    opt_cfg = self.training_config.optimizer_config

    if opt_cfg.optimizer_cls is not None:
        # Validate kwargs against the optimizer signature up front.
        try:
            _ = inspect.signature(opt_cfg.optimizer_cls).bind(
                self.learner.parameters(), **opt_cfg.optimizer_kwargs
            )
        except TypeError as e:
            raise ValueError(
                f"Invalid optimizer kwargs for {opt_cfg.optimizer_cls}; {e}"
            ) from e
        self.optimizer = opt_cfg.optimizer_cls(
            self.learner.parameters(), **opt_cfg.optimizer_kwargs
        )
    else:
        # Bug fix: the previous implementation returned early here without
        # ever assigning ``self.lr_scheduler``, leaving that attribute
        # undefined. Fall through so the scheduler branch below assigns it.
        self.optimizer = None

    # A scheduler only makes sense when an optimizer exists.
    if opt_cfg.scheduler_cls is not None and self.optimizer is not None:
        try:
            _ = inspect.signature(opt_cfg.scheduler_cls).bind(
                self.optimizer, **opt_cfg.scheduler_kwargs
            )
        except TypeError as e:
            raise ValueError(
                f"Invalid LR scheduler kwargs for {opt_cfg.scheduler_cls}; {e}"
            ) from e
        self.lr_scheduler = opt_cfg.scheduler_cls(
            self.optimizer, **opt_cfg.scheduler_kwargs
        )
    else:
        self.lr_scheduler = None

    # Optionally snapshot the pristine optimizer state so it can be
    # restored between active learning steps.
    if self.config.reset_optim_states and self.optimizer is not None:
        self._original_optim_state = deepcopy(self.optimizer.state_dict())
|
| 402 |
+
|
| 403 |
+
@property
def is_optimizer_configured(self) -> bool:
    """Whether ``configure_optimizer`` has produced an optimizer."""
    configured = getattr(self, "optimizer", None)
    return configured is not None

@property
def is_lr_scheduler_configured(self) -> bool:
    """Whether ``configure_optimizer`` has produced an LR scheduler."""
    scheduler = getattr(self, "lr_scheduler", None)
    return scheduler is not None
|
| 412 |
+
|
| 413 |
+
def attach_strategies(self) -> None:
    """Calls ``strategy.attach`` for all available strategies.

    Pure delegation: the actual attach loop is implemented by the
    ``DriverProtocol`` base class, which is defined outside this module.
    NOTE(review): confirm the base implementation's exact behavior
    (e.g. whether it passes ``self`` to each strategy).
    """
    super().attach_strategies()
|
| 416 |
+
|
| 417 |
+
def _setup_logger(self) -> None:
    """
    Create the driver's persistent logger.

    The raw logger is wrapped in an ``ActiveLearningLoggerAdapter`` that
    holds a back-reference to this driver, so emitted records can carry
    the current active learning step/phase context automatically.
    """
    root_logger = setup_active_learning_logger(
        "core.active_learning",
        run_id=self.run_id,
        log_dir=self.log_dir,
    )
    self.logger = ActiveLearningLoggerAdapter(root_logger, driver_ref=self)
|
| 431 |
+
|
| 432 |
+
def _should_checkpoint_at_step(self) -> bool:
|
| 433 |
+
"""
|
| 434 |
+
Determine if a checkpoint should be saved at the current AL step.
|
| 435 |
+
|
| 436 |
+
Uses the `checkpoint_interval` from config to decide. If interval is 0,
|
| 437 |
+
checkpointing is disabled. Otherwise, checkpoint at step 0 and every
|
| 438 |
+
N steps thereafter.
|
| 439 |
+
|
| 440 |
+
Returns
|
| 441 |
+
-------
|
| 442 |
+
bool
|
| 443 |
+
True if checkpoint should be saved, False otherwise.
|
| 444 |
+
"""
|
| 445 |
+
if self.config.checkpoint_interval == 0:
|
| 446 |
+
return False
|
| 447 |
+
# Always checkpoint at step 0, then every checkpoint_interval steps
|
| 448 |
+
return self.active_learning_step_idx % self.config.checkpoint_interval == 0
|
| 449 |
+
|
| 450 |
+
def _serialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> bool:
|
| 451 |
+
"""
|
| 452 |
+
Serialize queue to a file.
|
| 453 |
+
|
| 454 |
+
If queue implements `to_list()`, serialize the list. Otherwise, use
|
| 455 |
+
torch.save to serialize the entire queue object.
|
| 456 |
+
|
| 457 |
+
Parameters
|
| 458 |
+
----------
|
| 459 |
+
queue: p.AbstractQueue
|
| 460 |
+
The queue to serialize.
|
| 461 |
+
file_path: Path
|
| 462 |
+
Path where the queue should be saved.
|
| 463 |
+
|
| 464 |
+
Returns
|
| 465 |
+
-------
|
| 466 |
+
bool
|
| 467 |
+
True if serialization succeeded, False otherwise.
|
| 468 |
+
"""
|
| 469 |
+
try:
|
| 470 |
+
if hasattr(queue, "to_list") and callable(getattr(queue, "to_list")):
|
| 471 |
+
# Use custom serialization method
|
| 472 |
+
queue_data = {"type": "list", "data": queue.to_list()}
|
| 473 |
+
else:
|
| 474 |
+
# Fallback to torch.save for the entire queue
|
| 475 |
+
queue_data = {"type": "torch", "data": queue}
|
| 476 |
+
|
| 477 |
+
torch.save(queue_data, file_path)
|
| 478 |
+
return True
|
| 479 |
+
except (TypeError, AttributeError, pickle.PicklingError, RuntimeError) as e:
|
| 480 |
+
# Some queues cannot be pickled, e.g. stdlib queue.Queue with thread locks
|
| 481 |
+
# Clean up any partially written file
|
| 482 |
+
if file_path.exists():
|
| 483 |
+
file_path.unlink()
|
| 484 |
+
|
| 485 |
+
self.logger.warning(
|
| 486 |
+
f"Failed to serialize queue to {file_path}: {e}. Queue state will not be saved. "
|
| 487 |
+
f"Consider implementing to_list()/from_list() methods for custom serialization."
|
| 488 |
+
)
|
| 489 |
+
return False
|
| 490 |
+
|
| 491 |
+
def _deserialize_queue(self, queue: p.AbstractQueue, file_path: Path) -> None:
|
| 492 |
+
"""
|
| 493 |
+
Restore queue from a file.
|
| 494 |
+
|
| 495 |
+
Parameters
|
| 496 |
+
----------
|
| 497 |
+
queue: p.AbstractQueue
|
| 498 |
+
The queue to restore data into.
|
| 499 |
+
file_path: Path
|
| 500 |
+
Path to the saved queue file.
|
| 501 |
+
"""
|
| 502 |
+
if not file_path.exists():
|
| 503 |
+
return
|
| 504 |
+
|
| 505 |
+
try:
|
| 506 |
+
queue_data = torch.load(file_path, map_location="cpu", weights_only=False)
|
| 507 |
+
|
| 508 |
+
if queue_data["type"] == "list":
|
| 509 |
+
if hasattr(queue, "from_list") and callable(
|
| 510 |
+
getattr(queue, "from_list")
|
| 511 |
+
):
|
| 512 |
+
queue.from_list(queue_data["data"])
|
| 513 |
+
else:
|
| 514 |
+
# Manually populate queue from list
|
| 515 |
+
for item in queue_data["data"]:
|
| 516 |
+
queue.put(item)
|
| 517 |
+
elif queue_data["type"] == "torch":
|
| 518 |
+
# Restore from torch-saved queue - copy items to current queue
|
| 519 |
+
restored_queue = queue_data["data"]
|
| 520 |
+
# Copy items from restored queue to current queue
|
| 521 |
+
while not restored_queue.empty():
|
| 522 |
+
queue.put(restored_queue.get())
|
| 523 |
+
except Exception as e:
|
| 524 |
+
self.logger.warning(
|
| 525 |
+
f"Failed to deserialize queue from {file_path}: {e}. "
|
| 526 |
+
f"Queue will be empty."
|
| 527 |
+
)
|
| 528 |
+
|
| 529 |
+
def save_checkpoint(
    self, path: str | Path | None = None, training_epoch: int | None = None
) -> Path | None:
    """
    Save a checkpoint of the active learning experiment.

    Saves AL orchestration state (configs, queues, step index, phase) and model weights.
    Training-specific state (optimizer, scheduler) is handled by DefaultTrainingLoop
    and saved to training_state.pt during training.

    Parameters
    ----------
    path: str | Path | None
        Path to save checkpoint. If None, creates path based on current
        AL step index and phase: log_dir/checkpoints/step_{idx}/{phase}/
    training_epoch: int | None
        Optional epoch number for mid-training checkpoints.

    Returns
    -------
    Path | None
        Checkpoint directory path, or None if checkpoint not saved (non-rank-0 in distributed).
    """
    # Determine checkpoint directory: either caller-provided, or derived
    # from the current step index and phase under the run's log dir.
    if path is None:
        phase_name = self.current_phase if self.current_phase else "init"
        checkpoint_dir = (
            self.log_dir
            / "checkpoints"
            / f"step_{self.active_learning_step_idx}"
            / phase_name
        )
        if training_epoch is not None:
            checkpoint_dir = checkpoint_dir / f"epoch_{training_epoch}"
    else:
        checkpoint_dir = Path(path)

    # Create checkpoint directory.
    # NOTE(review): this runs on every rank *before* the rank-0 gate below,
    # so all ranks create the directory (exist_ok=True makes this benign on
    # a shared filesystem) — confirm this ordering is intended.
    checkpoint_dir.mkdir(parents=True, exist_ok=True)

    # Only rank 0 saves checkpoint in distributed setting; other ranks
    # signal "not saved" by returning None.
    if self.dist_manager is not None and self.dist_manager.is_initialized():
        if self.dist_manager.rank != 0:
            return None

    # Serialize configurations.
    # NOTE(review): these serialized forms (JSON string / dicts) are passed
    # into ActiveLearningCheckpoint fields annotated with the config
    # classes themselves — confirm the intended field types.
    driver_config_json = self.config.to_json()
    strategies_config_dict = self.strategies_config.to_dict()
    training_config_dict = (
        self.training_config.to_dict() if self.training_config else None
    )

    # Serialize queue states to separate files next to the checkpoint;
    # the booleans record whether each serialization succeeded.
    query_queue_file = checkpoint_dir / "query_queue.pt"
    label_queue_file = checkpoint_dir / "label_queue.pt"
    has_query_queue = self._serialize_queue(self.query_queue, query_queue_file)
    has_label_queue = self._serialize_queue(self.label_queue, label_queue_file)

    # Create checkpoint dataclass (only AL orchestration state; optimizer
    # and scheduler states are persisted by the training loop, not here).
    checkpoint = ActiveLearningCheckpoint(
        driver_config=driver_config_json,
        strategies_config=strategies_config_dict,
        active_learning_step_idx=self.active_learning_step_idx,
        active_learning_phase=self.current_phase or p.ActiveLearningPhase.TRAINING,
        physicsnemo_version=physicsnemo_version,
        training_config=training_config_dict,
        optimizer_state=None,  # Training loop handles this
        lr_scheduler_state=None,  # Training loop handles this
        has_query_queue=has_query_queue,
        has_label_queue=has_label_queue,
    )

    # Record the training epoch only for mid-training checkpoints.
    checkpoint_dict = {
        "checkpoint": checkpoint,
    }
    if training_epoch is not None:
        checkpoint_dict["training_epoch"] = training_epoch

    # Save checkpoint metadata.
    checkpoint_path = checkpoint_dir / "checkpoint.pt"
    torch.save(checkpoint_dict, checkpoint_path)

    # Save model weights (separate from training state). Three cases:
    # a native physicsnemo Module, a DDP-style wrapper around one, or an
    # arbitrary torch module saved via its state_dict.
    if isinstance(self.learner, Module):
        model_name = (
            self.learner.meta.name
            if self.learner.meta
            else self.learner.__class__.__name__
        )
        model_path = checkpoint_dir / f"{model_name}.mdlus"
        self.learner.save(str(model_path))
    elif hasattr(self.learner, "module") and isinstance(
        self.learner.module, Module
    ):
        # Unwrap DDP (or any wrapper exposing ``.module``) to reach the
        # underlying physicsnemo Module.
        model_name = (
            self.learner.module.meta.name
            if self.learner.module.meta
            else self.learner.module.__class__.__name__
        )
        model_path = checkpoint_dir / f"{model_name}.mdlus"
        self.learner.module.save(str(model_path))
    else:
        # Generic fallback: assumes the learner exposes ``state_dict()``
        # — presumably a torch.nn.Module; TODO confirm for custom learners.
        model_name = self.learner.__class__.__name__
        model_path = checkpoint_dir / f"{model_name}.pt"
        torch.save(self.learner.state_dict(), model_path)

    # Update last checkpoint path so ``last_checkpoint`` reflects this save.
    self._last_checkpoint_path = checkpoint_dir

    # Log successful checkpoint save.
    self.logger.info(
        f"Saved checkpoint at step {self.active_learning_step_idx}, "
        f"phase {self.current_phase}: {checkpoint_dir}"
    )

    return checkpoint_dir
|
| 647 |
+
|
| 648 |
+
@classmethod
def load_checkpoint(
    cls,
    checkpoint_path: str | Path,
    learner: Module | p.LearnerProtocol | None = None,
    train_datapool: p.DataPool | None = None,
    val_datapool: p.DataPool | None = None,
    unlabeled_datapool: p.DataPool | None = None,
    **kwargs: Any,
) -> Driver:
    """
    Load a Driver instance from a checkpoint.

    Given a checkpoint directory, this method will attempt to reconstruct
    the driver and its associated components from the checkpoint. The
    checkpoint path must contain a ``checkpoint.pt`` file, which contains
    the metadata associated with the experiment.

    Additional parameters that might not be serialized with the checkpointing
    mechanism can/need to be provided to this method; for example when
    using non-`physicsnemo.Module` learners, and any data pools associated
    with the workflow.

    .. important::

        Currently, the strategy states are not reloaded from the checkpoint.
        This will be addressed in a future patch, but for now it is recommended
        to back up your strategy states (e.g. metrology records) manually
        before restarting experiments.

    Parameters
    ----------
    checkpoint_path: str | Path
        Path to checkpoint directory containing checkpoint.pt and model weights.
    learner: Module | p.LearnerProtocol | None
        Learner model to load weights into. If None, will attempt to
        reconstruct from checkpoint (only works for physicsnemo.Module).
    train_datapool: p.DataPool | None
        Training datapool. Required if training_config exists in checkpoint.
    val_datapool: p.DataPool | None
        Validation datapool. Optional.
    unlabeled_datapool: p.DataPool | None
        Unlabeled datapool for query strategies. Optional.
    **kwargs: Any
        Additional keyword arguments to override config values. Recognized
        keys used below: ``driver_config_overrides``,
        ``strategies_config_overrides``, ``training_config_overrides``,
        ``inference_fn``.

    Returns
    -------
    Driver
        Reconstructed Driver instance ready to resume execution.

    Raises
    ------
    FileNotFoundError
        If ``checkpoint.pt`` does not exist under ``checkpoint_path``.
    ValueError
        If no learner was provided and no ``.mdlus`` file could be found
        to reconstruct one.
    """
    checkpoint_path = Path(checkpoint_path)

    # Load checkpoint file
    checkpoint_file = checkpoint_path / "checkpoint.pt"
    if not checkpoint_file.exists():
        raise FileNotFoundError(f"Checkpoint file not found: {checkpoint_file}")

    # NOTE(review): weights_only=False deserializes arbitrary pickled objects;
    # only load checkpoints from trusted sources.
    checkpoint_dict = torch.load(
        checkpoint_file, map_location="cpu", weights_only=False
    )
    checkpoint: ActiveLearningCheckpoint = checkpoint_dict["checkpoint"]
    # Present only for mid-training checkpoints; used for log messaging below.
    training_epoch = checkpoint_dict.get("training_epoch", None)

    # Reconstruct configs
    driver_config = DriverConfig.from_json(
        checkpoint.driver_config, **kwargs.get("driver_config_overrides", {})
    )

    # TODO add strategy state loading from checkpoint
    strategies_config = StrategiesConfig.from_dict(
        checkpoint.strategies_config,
        unlabeled_datapool=unlabeled_datapool,
        **kwargs.get("strategies_config_overrides", {}),
    )

    training_config = None
    if checkpoint.training_config is not None:
        training_config = TrainingConfig.from_dict(
            checkpoint.training_config,
            train_datapool=train_datapool,
            val_datapool=val_datapool,
            **kwargs.get("training_config_overrides", {}),
        )

    # Load or reconstruct learner
    if learner is None:
        # Attempt to reconstruct from checkpoint (only for Module)
        # Try to find any .mdlus file in the checkpoint directory
        mdlus_files = list(checkpoint_path.glob("*.mdlus"))
        if mdlus_files:
            # Use the first .mdlus file found
            model_path = mdlus_files[0]
            learner = Module.from_checkpoint(str(model_path))
        else:
            raise ValueError(
                "No learner provided and unable to reconstruct from checkpoint. "
                "Please provide a learner instance."
            )
    else:
        # Load model weights into provided learner
        # Determine expected model filename based on learner type
        if isinstance(learner, Module):
            model_name = (
                learner.meta.name if learner.meta else learner.__class__.__name__
            )
            model_path = checkpoint_path / f"{model_name}.mdlus"
            if model_path.exists():
                learner.load(str(model_path))
            else:
                # Fallback: try to find any .mdlus file
                mdlus_files = list(checkpoint_path.glob("*.mdlus"))
                if mdlus_files:
                    learner.load(str(mdlus_files[0]))
        elif hasattr(learner, "module") and isinstance(learner.module, Module):
            # Unwrap DDP
            model_name = (
                learner.module.meta.name
                if learner.module.meta
                else learner.module.__class__.__name__
            )
            model_path = checkpoint_path / f"{model_name}.mdlus"
            if model_path.exists():
                learner.module.load(str(model_path))
            else:
                # Fallback: try to find any .mdlus file
                mdlus_files = list(checkpoint_path.glob("*.mdlus"))
                if mdlus_files:
                    learner.module.load(str(mdlus_files[0]))
        else:
            # Non-Module learner: look for .pt file with class name
            model_name = learner.__class__.__name__
            model_path = checkpoint_path / f"{model_name}.pt"
            if model_path.exists():
                state_dict = torch.load(model_path, map_location="cpu")
                learner.load_state_dict(state_dict)
            else:
                # Fallback: try to find any .pt file
                pt_files = list(checkpoint_path.glob("*.pt"))
                # Filter out checkpoint.pt and queue files
                model_pt_files = [
                    f
                    for f in pt_files
                    if f.name
                    not in [
                        "checkpoint.pt",
                        "query_queue.pt",
                        "label_queue.pt",
                        "training_state.pt",
                    ]
                ]
                if model_pt_files:
                    state_dict = torch.load(model_pt_files[0], map_location="cpu")
                    learner.load_state_dict(state_dict)

    # Instantiate Driver
    driver = cls(
        config=driver_config,
        learner=learner,
        strategies_config=strategies_config,
        training_config=training_config,
        inference_fn=kwargs.get("inference_fn", None),
    )

    # Restore active learning state
    driver.active_learning_step_idx = checkpoint.active_learning_step_idx
    driver.current_phase = checkpoint.active_learning_phase
    driver._last_checkpoint_path = checkpoint_path

    # Load training state (optimizer, scheduler) if training_config exists
    # This delegates to the training loop's checkpoint loading logic
    if driver.training_config is not None:
        driver.configure_optimizer()

        # Use training loop to load training state (including model weights again if needed)
        from physicsnemo.active_learning.loop import DefaultTrainingLoop

        DefaultTrainingLoop.load_training_checkpoint(
            checkpoint_dir=checkpoint_path,
            model=driver.learner,
            optimizer=driver.optimizer,
            lr_scheduler=driver.lr_scheduler
            if hasattr(driver, "lr_scheduler")
            else None,
        )

    # Restore queue states from separate files
    if checkpoint.has_query_queue:
        query_queue_file = checkpoint_path / "query_queue.pt"
        driver._deserialize_queue(driver.query_queue, query_queue_file)

    if checkpoint.has_label_queue:
        label_queue_file = checkpoint_path / "label_queue.pt"
        driver._deserialize_queue(driver.label_queue, label_queue_file)

    driver.logger.info(
        f"Loaded checkpoint from {checkpoint_path} at step "
        f"{checkpoint.active_learning_step_idx}, phase {checkpoint.active_learning_phase}"
    )
    if training_epoch is not None:
        driver.logger.info(f"Resuming from training epoch {training_epoch}")

    return driver
|
| 851 |
+
|
| 852 |
+
def barrier(self) -> None:
    """
    Synchronize all ranks, pinning the barrier to the correct device.

    Becomes a no-op when ``torch.distributed`` is not initialized. When it
    is, the CUDA device id is taken from the distributed manager if one is
    available; otherwise the default device is consulted so the barrier can
    still target the current CUDA device. On non-CUDA setups a plain
    barrier is issued.
    """
    if not dist.is_initialized():
        return
    manager = self.dist_manager
    if manager is not None and manager.device.type == "cuda":
        dist.barrier(device_ids=[manager.local_rank])
    elif torch.get_default_device().type == "cuda":
        # this might occur if distributed manager is not used
        dist.barrier(device_ids=[torch.cuda.current_device()])
    else:
        dist.barrier()
|
| 871 |
+
|
| 872 |
+
def _configure_model(self) -> None:
    """
    Method that encapsulates all the logic for preparing the model
    ahead of time.

    If the distributed manager has been configured and initialized
    with a world size greater than 1, then we wrap the model in DDP.
    Otherwise, we simply move the model to the correct device.

    After the model has been moved to device, we configure the optimizer
    and learning rate scheduler if training is enabled.
    """
    if self.dist_manager is not None and self.dist_manager.is_initialized():
        if self.dist_manager.world_size > 1 and not isinstance(
            self.learner, DistributedDataParallel
        ):
            # wrap the model in DDP
            # NOTE(review): DDP expects parameters to already reside on the
            # local device; presumably the dist_manager has placed the model
            # there before this point — TODO confirm.
            self.learner = torch.nn.parallel.DistributedDataParallel(
                self.learner,
                device_ids=[self.dist_manager.local_rank],
                output_device=self.dist_manager.device,
                broadcast_buffers=self.dist_manager.broadcast_buffers,
                find_unused_parameters=self.dist_manager.find_unused_parameters,
            )
    else:
        # single-process path: honor an explicitly configured device/dtype
        if self.config.device is not None:
            self.learner = self.learner.to(self.config.device, self.config.dtype)
    # assume all device management is done via the dist_manager, so at this
    # point the model is on the correct device and we can set up the optimizer
    # if we intend to train
    if not self.config.skip_training and not self.is_optimizer_configured:
        self.configure_optimizer()
    # optionally restore the optimizer to its pristine state each AL step
    if self.is_optimizer_configured and self.config.reset_optim_states:
        self.optimizer.load_state_dict(self._original_optim_state)
|
| 906 |
+
|
| 907 |
+
def _get_phase_index(self, phase: p.ActiveLearningPhase | None) -> int:
    """
    Map a phase onto its position in the canonical execution order.

    Parameters
    ----------
    phase: p.ActiveLearningPhase | None
        Phase to locate. ``None`` (and any unrecognized phase) maps to 0,
        i.e. start from the beginning.

    Returns
    -------
    int
        Index into ``_PHASE_ORDER`` (0-3).
    """
    if phase is None:
        return 0
    if phase in self._PHASE_ORDER:
        return self._PHASE_ORDER.index(phase)
    # unrecognized phases degrade gracefully to a full restart of the step
    self.logger.warning(
        f"Unknown phase {phase}, defaulting to start from beginning"
    )
    return 0
|
| 930 |
+
|
| 931 |
+
def _build_phase_queue(
    self,
    train_step_fn: p.TrainingProtocol | None,
    validate_step_fn: p.ValidationProtocol | None,
    args: tuple,
    kwargs: dict,
) -> list[Any]:
    """
    Build list of phase functions to execute for this AL step.

    If current_phase is set (e.g., from checkpoint), only phases at or after
    current_phase are included. Otherwise, all non-skipped phases are included.

    Parameters
    ----------
    train_step_fn: p.TrainingProtocol | None
        Training function to pass to training phase.
    validate_step_fn: p.ValidationProtocol | None
        Validation function to pass to training phase.
    args: tuple
        Additional arguments to pass to phase methods.
    kwargs: dict
        Additional keyword arguments to pass to phase methods.

    Returns
    -------
    list[Callable]
        Queue of phase functions to execute in order.
    """
    # Define all possible phases with their execution conditions.
    # Each entry is (phase enum, zero-arg thunk, should-run flag); the thunks
    # close over train_step_fn/validate_step_fn/args/kwargs, which are not
    # mutated between construction and invocation within this AL step.
    all_phases = [
        (
            p.ActiveLearningPhase.TRAINING,
            lambda: self._training_phase(
                train_step_fn, validate_step_fn, *args, **kwargs
            ),
            not self.config.skip_training,
        ),
        (
            p.ActiveLearningPhase.METROLOGY,
            lambda: self._metrology_phase(*args, **kwargs),
            not self.config.skip_metrology,
        ),
        (
            p.ActiveLearningPhase.QUERY,
            lambda: self._query_phase(*args, **kwargs),
            True,  # Query phase always runs
        ),
        (
            p.ActiveLearningPhase.LABELING,
            lambda: self._labeling_phase(*args, **kwargs),
            not self.config.skip_labeling,
        ),
    ]

    # Find starting index based on current_phase (resume point)
    start_idx = self._get_phase_index(self.current_phase)

    if start_idx > 0:
        self.logger.info(
            f"Resuming AL step {self.active_learning_step_idx} from "
            f"{self.current_phase}"
        )

    # Build queue: only phases from start_idx onwards that should run
    phase_queue = []
    for idx, (phase, phase_fn, should_run) in enumerate(all_phases):
        # Skip phases before current_phase
        if idx < start_idx:
            self.logger.debug(
                f"Skipping {phase} (already completed in this AL step)"
            )
            continue

        # Add phase to queue if not skipped by config
        if should_run:
            phase_queue.append(phase_fn)
        else:
            self.logger.debug(f"Skipping {phase} (disabled in config)")

    return phase_queue
|
| 1012 |
+
|
| 1013 |
+
def _construct_dataloader(
    self, pool: p.DataPool, shuffle: bool = False, drop_last: bool = False
) -> DataLoader:
    """
    Helper method to construct a data loader for a given data pool.

    In the case that a distributed manager was provided, then a distributed
    sampler will be used, which will be bound to the current rank.
    Otherwise, a regular sampler will be used. Similarly, if your data
    structure requires a specialized function to construct batches,
    then this function can be provided via the `collate_fn` argument.

    Parameters
    ----------
    pool: p.DataPool
        The data pool to construct a data loader for.
    shuffle: bool = False
        Whether to shuffle the data.
    drop_last: bool = False
        Whether to drop the last batch if it is not complete.

    Returns
    -------
    DataLoader
        The constructed data loader.
    """
    # if a distributed manager was omitted, then we assume single process
    if self.dist_manager is not None and self.dist_manager.is_initialized():
        sampler = DistributedSampler(
            pool,
            num_replicas=self.dist_manager.world_size,
            rank=self.dist_manager.rank,
            shuffle=shuffle,
            drop_last=drop_last,
        )
        # set to None, because sampler will handle instead
        # (DataLoader forbids shuffle=True together with a sampler)
        shuffle = None
    else:
        sampler = None
    # fully spec out the data loader
    # pin host memory only when transferring to a CUDA device
    pin_memory = False
    if self.dist_manager is not None and self.dist_manager.is_initialized():
        if self.dist_manager.device.type == "cuda":
            pin_memory = True
    loader = DataLoader(
        pool,
        shuffle=shuffle,
        sampler=sampler,
        collate_fn=self.config.collate_fn,
        batch_size=self.config.batch_size,
        num_workers=self.config.num_dataloader_workers,
        # keep worker processes alive between epochs when workers are used
        persistent_workers=self.config.num_dataloader_workers > 0,
        pin_memory=pin_memory,
    )
    return loader
|
| 1068 |
+
|
| 1069 |
+
def active_learning_step(
    self,
    train_step_fn: p.TrainingProtocol | None = None,
    validate_step_fn: p.ValidationProtocol | None = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Performs a single active learning iteration.

    The iteration proceeds through up to four phases, in order:

    1. Training of ``Driver.learner`` using loaders built from
       ``Driver.train_datapool`` / ``Driver.val_datapool``.
    2. Metrology via ``Driver.metrology_strategies``.
    3. Querying via ``Driver.query_strategies``, if available.
    4. Labeling via ``Driver.label_strategy``, if available.

    Each phase verifies its own prerequisites before running. When
    ``current_phase`` is set (e.g. after checkpoint resumption), only that
    phase and later ones execute; ``current_phase`` is cleared once the
    step finishes so the next step starts from the top.

    Parameters
    ----------
    train_step_fn: p.TrainingProtocol | None = None
        Training function to use; falls back to ``Driver.train_loop_fn``
        behavior when omitted.
    validate_step_fn: p.ValidationProtocol | None = None
        Validation function to use; validation is skipped when omitted.
    args: Any
        Extra positional arguments forwarded to every phase.
    kwargs: Any
        Extra keyword arguments forwarded to every phase.

    Raises
    ------
    ValueError
        If a phase's required components are missing.
    """
    self._setup_active_learning_step()

    # Assemble the ordered phase callables for this step; resumption is
    # honored because the queue respects ``current_phase``.
    pending_phases = self._build_phase_queue(
        train_step_fn, validate_step_fn, args, kwargs
    )
    for run_phase in pending_phases:
        run_phase()

    # Every phase completed: clear the resume marker for the next step.
    self.current_phase = None

    self.logger.debug("Entering barrier for synchronization.")
    self.barrier()
    self.active_learning_step_idx += 1
    self.logger.info(
        f"Completed active learning step {self.active_learning_step_idx}"
    )
|
| 1133 |
+
|
| 1134 |
+
def _setup_active_learning_step(self) -> None:
    """Lazily initialize the distributed manager, then prepare the model for this step."""
    manager = self.dist_manager
    if manager is not None and not manager.is_initialized():
        self.logger.info(
            "Distributed manager configured but not initialized; initializing."
        )
        manager.initialize()
    self._configure_model()
    self.logger.info(
        f"Starting active learning step {self.active_learning_step_idx}"
    )
|
| 1145 |
+
|
| 1146 |
+
def _training_phase(
    self,
    train_step_fn: p.TrainingProtocol | None,
    validate_step_fn: p.ValidationProtocol | None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Execute the training phase of the active learning step.

    Builds train/validation loaders, optionally adjusts the learning rate
    for fine-tuning, resolves how many epochs remain (accounting for
    mid-training checkpoints), and delegates to ``self.train_loop_fn``.
    """
    self._validate_training_requirements(train_step_fn, validate_step_fn)

    # don't need to barrier because it'll be done at the end of training anyway
    with self._phase_context("training", call_barrier=False):
        # Note: Training phase checkpointing is handled by the training loop itself
        # during epoch execution based on model_checkpoint_frequency

        train_loader = self._construct_dataloader(self.train_datapool, shuffle=True)
        self.logger.info(
            f"There are {len(train_loader)} batches in the training loader."
        )
        val_loader = None
        if self.val_datapool is not None:
            # only build a validation loader when something can consume it
            if validate_step_fn or hasattr(self.learner, "validation_step"):
                val_loader = self._construct_dataloader(
                    self.val_datapool, shuffle=False
                )
            else:
                self.logger.warning(
                    "Validation data is available, but no `validate_step_fn` "
                    "or `validation_step` method in Learner is provided."
                )
        # if a fine-tuning lr is provided, adjust it after the first iteration
        # NOTE(review): only param_groups[0] is updated — assumes a single
        # parameter group; TODO confirm for multi-group optimizers.
        if (
            self.config.fine_tuning_lr is not None
            and self.active_learning_step_idx > 0
        ):
            self.optimizer.param_groups[0]["lr"] = self.config.fine_tuning_lr

        # Determine max epochs to train for this AL step
        if self.active_learning_step_idx > 0:
            target_max_epochs = self.training_config.max_fine_tuning_epochs
        else:
            target_max_epochs = self.training_config.max_training_epochs

        # Check if resuming from mid-training checkpoint
        start_epoch = 1  # only used for the log message below
        epochs_to_train = target_max_epochs

        # NOTE(review): `_last_checkpoint_path` may point at a checkpoint from
        # a previous AL step; presumably its training_state.pt is only present
        # for the step being resumed — TODO confirm.
        if self._last_checkpoint_path and self._last_checkpoint_path.exists():
            training_state_path = self._last_checkpoint_path / "training_state.pt"
            if training_state_path.exists():
                training_state = torch.load(
                    training_state_path, map_location="cpu", weights_only=False
                )
                last_completed_epoch = training_state.get("training_epoch", 0)
                if last_completed_epoch > 0:
                    start_epoch = last_completed_epoch + 1
                    epochs_to_train = target_max_epochs - last_completed_epoch
                    self.logger.info(
                        f"Resuming training from epoch {start_epoch} "
                        f"({epochs_to_train} epochs remaining)"
                    )

        # Skip training if all epochs already completed
        if epochs_to_train <= 0:
            self.logger.info(
                f"Training already complete ({target_max_epochs} epochs), "
                f"skipping training phase"
            )
            return

        device = (
            self.dist_manager.device
            if self.dist_manager is not None
            else self.config.device
        )
        dtype = self.config.dtype

        # Set checkpoint directory and frequency on training loop
        # This allows the training loop to handle training state checkpointing internally
        if hasattr(self.train_loop_fn, "checkpoint_base_dir") and hasattr(
            self.train_loop_fn, "checkpoint_frequency"
        ):
            # Checkpoint base is the current AL step's training directory
            checkpoint_base = (
                self.log_dir
                / "checkpoints"
                / f"step_{self.active_learning_step_idx}"
                / "training"
            )
            self.train_loop_fn.checkpoint_base_dir = checkpoint_base
            self.train_loop_fn.checkpoint_frequency = (
                self.config.model_checkpoint_frequency
            )

        self.train_loop_fn(
            self.learner,
            self.optimizer,
            train_step_fn=train_step_fn,
            validate_step_fn=validate_step_fn,
            train_dataloader=train_loader,
            validation_dataloader=val_loader,
            lr_scheduler=self.lr_scheduler,
            max_epochs=epochs_to_train,  # Only remaining epochs
            device=device,
            dtype=dtype,
            **kwargs,
        )
|
| 1253 |
+
|
| 1254 |
+
def _metrology_phase(self, *args: Any, **kwargs: Any) -> None:
    """Run every configured metrology strategy, persisting each one's records."""
    with self._phase_context("metrology"):
        for strategy in self.metrology_strategies:
            strategy_name = strategy.__class__.__name__
            self.logger.info(
                f"Running metrology strategy: {strategy_name}"
            )
            strategy(*args, **kwargs)
            self.logger.info(
                f"Completed metrics for strategy: {strategy_name}"
            )
            # persist this strategy's records immediately after it runs
            strategy.serialize_records(*args, **kwargs)
|
| 1267 |
+
|
| 1268 |
+
def _query_phase(self, *args: Any, **kwargs: Any) -> None:
    """Run all query strategies against the shared queue; warn if nothing was selected."""
    with self._phase_context("query"):
        queue = self.query_queue
        for strategy in self.query_strategies:
            name = strategy.__class__.__name__
            self.logger.info(
                f"Running query strategy: {name}"
            )
            strategy(queue, *args, **kwargs)

        if queue.empty():
            self.logger.warning(
                "Querying strategies produced no samples this iteration."
            )
|
| 1281 |
+
|
| 1282 |
+
def _labeling_phase(self, *args: Any, **kwargs: Any) -> None:
    """
    Execute the labeling phase of the active learning step.

    Drains the query queue through ``self.label_strategy`` into the label
    queue, then appends every labeled sample to the training pool.
    """
    self._validate_labeling_requirements()

    if self.query_queue.empty():
        self.logger.warning("No samples to label. Skipping labeling phase.")
        return

    with self._phase_context("labeling"):
        try:
            self.label_strategy(self.query_queue, self.label_queue, *args, **kwargs)
        except Exception as e:
            # NOTE(review): labeling errors are logged and swallowed, so any
            # samples labeled before the failure are still appended below —
            # confirm this best-effort semantic is intended.
            self.logger.error(f"Exception encountered during labeling: {e}")
        self.logger.info("Labeling completed. Now appending to training pool.")

        # TODO this is done serially, could be improved with batched writes
        sample_counter = 0
        while not self.label_queue.empty():
            self.train_datapool.append(self.label_queue.get())
            sample_counter += 1
        self.logger.info(f"Appended {sample_counter} samples to training pool.")
|
| 1303 |
+
|
| 1304 |
+
def _validate_training_requirements(
    self,
    train_step_fn: p.TrainingProtocol | None,
    validate_step_fn: p.ValidationProtocol | None,
) -> None:
    """Fail fast when a component required for the training phase is missing."""
    if self.training_config is None:
        raise ValueError(
            "`training_config` must be provided if `skip_training` is False."
        )
    if self.train_loop_fn is None:
        raise ValueError("`train_loop_fn` must be provided in training_config.")
    if self.train_datapool is None:
        raise ValueError("`train_datapool` must be provided in training_config.")
    # a per-batch step must come from either the caller or the learner itself
    learner_has_step = hasattr(self.learner, "training_step")
    if not train_step_fn and not learner_has_step:
        raise ValueError(
            "`train_step_fn` must be provided if the model does not implement "
            "the `training_step` method."
        )
    if validate_step_fn and self.val_datapool is None:
        raise ValueError(
            "`val_datapool` must be provided in training_config if "
            "`validate_step_fn` is provided."
        )
|
| 1328 |
+
|
| 1329 |
+
def _validate_labeling_requirements(self) -> None:
    """Fail fast when a component required for the labeling phase is missing."""
    if self.label_strategy is None:
        raise ValueError(
            "`label_strategy` must be provided in strategies_config if "
            "`skip_labeling` is False."
        )
    # labeled samples are appended to the training pool, so it must exist
    pool_missing = self.training_config is None or self.train_datapool is None
    if pool_missing:
        raise ValueError(
            "`train_datapool` must be provided in training_config for "
            "labeling, as data will be appended to it."
        )
|
| 1341 |
+
|
| 1342 |
+
@contextmanager
def _phase_context(
    self, phase_name: p.ActiveLearningPhase, call_barrier: bool = True
) -> Generator[None, Any, None]:
    """
    Context manager for consistent phase tracking, error handling, and synchronization.

    Sets the current phase for logging context, handles exceptions,
    and synchronizes distributed workers with a barrier. Also triggers
    checkpoint saves at the start of each phase if configured.

    Parameters
    ----------
    phase_name: p.ActiveLearningPhase
        A discrete phase of the active learning workflow.
    call_barrier: bool
        Whether to call barrier for synchronization at the end.
    """
    self.current_phase = phase_name

    # Save checkpoint at START of phase if configured
    # Exception: training phase handles checkpointing internally
    if phase_name != p.ActiveLearningPhase.TRAINING:
        # NOTE(review): the config attribute key relies on the phase
        # formatting as its plain name in the f-string — TODO confirm the
        # enum's __str__/__format__ yields e.g. "query", not "ActiveLearningPhase.QUERY".
        should_checkpoint = getattr(
            self.config, f"checkpoint_on_{phase_name}", False
        )
        # Check if we should checkpoint based on interval
        if should_checkpoint and self._should_checkpoint_at_step():
            self.save_checkpoint()

    try:
        yield
    except Exception as e:
        # log with phase context, then propagate to the caller
        self.logger.error(f"Exception encountered during {phase_name}: {e}")
        raise
    finally:
        # barrier runs on both success and failure paths when requested
        if call_barrier:
            self.logger.debug("Entering barrier for synchronization.")
            self.barrier()
|
| 1381 |
+
|
| 1382 |
+
def run(
|
| 1383 |
+
self,
|
| 1384 |
+
train_step_fn: p.TrainingProtocol | None = None,
|
| 1385 |
+
validate_step_fn: p.ValidationProtocol | None = None,
|
| 1386 |
+
*args: Any,
|
| 1387 |
+
**kwargs: Any,
|
| 1388 |
+
) -> None:
|
| 1389 |
+
"""
|
| 1390 |
+
Runs the active learning loop until the maximum number of
|
| 1391 |
+
active learning steps is reached.
|
| 1392 |
+
|
| 1393 |
+
Parameters
|
| 1394 |
+
----------
|
| 1395 |
+
train_step_fn: p.TrainingProtocol | None = None
|
| 1396 |
+
The training function to use for training. If not provided, then the
|
| 1397 |
+
``Driver.train_loop_fn`` will be used.
|
| 1398 |
+
validate_step_fn: p.ValidationProtocol | None = None
|
| 1399 |
+
The validation function to use for validation. If not provided, then
|
| 1400 |
+
validation will not be performed.
|
| 1401 |
+
args: Any
|
| 1402 |
+
Additional arguments to pass to the method. These will be passed to the
|
| 1403 |
+
training loop, metrology strategies, query strategies, and labeling strategies.
|
| 1404 |
+
kwargs: Any
|
| 1405 |
+
Additional keyword arguments to pass to the method. These will be passed to the
|
| 1406 |
+
training loop, metrology strategies, query strategies, and labeling strategies.
|
| 1407 |
+
"""
|
| 1408 |
+
# TODO: refactor initialization logic here instead of inside the step
|
| 1409 |
+
while self.active_learning_step_idx < self.config.max_active_learning_steps:
|
| 1410 |
+
self.active_learning_step(
|
| 1411 |
+
train_step_fn=train_step_fn,
|
| 1412 |
+
validate_step_fn=validate_step_fn,
|
| 1413 |
+
*args,
|
| 1414 |
+
**kwargs,
|
| 1415 |
+
)
|
| 1416 |
+
|
| 1417 |
+
def __call__(
|
| 1418 |
+
self,
|
| 1419 |
+
train_step_fn: p.TrainingProtocol | None = None,
|
| 1420 |
+
validate_step_fn: p.ValidationProtocol | None = None,
|
| 1421 |
+
*args: Any,
|
| 1422 |
+
**kwargs: Any,
|
| 1423 |
+
) -> None:
|
| 1424 |
+
"""
|
| 1425 |
+
Provides syntactic sugar for running the active learning loop.
|
| 1426 |
+
|
| 1427 |
+
Calls ``Driver.run`` internally.
|
| 1428 |
+
|
| 1429 |
+
Parameters
|
| 1430 |
+
----------
|
| 1431 |
+
train_step_fn: p.TrainingProtocol | None = None
|
| 1432 |
+
The training function to use for training. If not provided, then the
|
| 1433 |
+
``Driver.train_loop_fn`` will be used.
|
| 1434 |
+
validate_step_fn: p.ValidationProtocol | None = None
|
| 1435 |
+
The validation function to use for validation. If not provided, then
|
| 1436 |
+
validation will not be performed.
|
| 1437 |
+
args: Any
|
| 1438 |
+
Additional arguments to pass to the method. These will be passed to the
|
| 1439 |
+
training loop, metrology strategies, query strategies, and labeling strategies.
|
| 1440 |
+
kwargs: Any
|
| 1441 |
+
Additional keyword arguments to pass to the method. These will be passed to the
|
| 1442 |
+
training loop, metrology strategies, query strategies, and labeling strategies.
|
| 1443 |
+
"""
|
| 1444 |
+
self.run(
|
| 1445 |
+
train_step_fn=train_step_fn,
|
| 1446 |
+
validate_step_fn=validate_step_fn,
|
| 1447 |
+
*args,
|
| 1448 |
+
**kwargs,
|
| 1449 |
+
)
|
physics_mcp/source/physicsnemo/active_learning/logger.py
ADDED
|
@@ -0,0 +1,330 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import json
|
| 20 |
+
import logging
|
| 21 |
+
from contextlib import contextmanager
|
| 22 |
+
from datetime import datetime
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
from threading import local
|
| 25 |
+
from typing import Any
|
| 26 |
+
|
| 27 |
+
try:
|
| 28 |
+
from termcolor import colored
|
| 29 |
+
except ImportError:
|
| 30 |
+
colored = None
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
# Thread-local storage for context information
|
| 34 |
+
_context_storage = local()
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
class ActiveLearningLoggerAdapter(logging.LoggerAdapter):
    """Logger adapter that automatically includes active learning iteration context.

    This adapter automatically adds iteration information to log messages
    by accessing the driver's current iteration state.
    """

    def __init__(self, logger: logging.Logger, driver_ref: Any = None):
        """Initialize the adapter with a logger and optional driver reference.

        Parameters
        ----------
        logger : logging.Logger
            The underlying logger to adapt
        driver_ref : Any, optional
            Reference to the driver object to get iteration context from
        """
        super().__init__(logger, {})
        self.driver_ref = driver_ref

    def process(self, msg: str, kwargs: dict[str, Any]) -> tuple[str, dict[str, Any]]:
        """Process the log message to add iteration, run ID, and phase context.

        Parameters
        ----------
        msg : str
            The log message
        kwargs : dict[str, Any]
            Additional keyword arguments

        Returns
        -------
        tuple[str, dict[str, Any]]
            Processed message and kwargs
        """
        driver = self.driver_ref
        if driver is not None:
            extra = kwargs.get("extra", {})
            # Map driver attributes onto "extra" record fields; any attribute
            # that is missing or None is simply skipped.
            for attr_name, extra_key in (
                ("active_learning_step_idx", "iteration"),
                ("run_id", "run_id"),
                ("current_phase", "phase"),
            ):
                value = getattr(driver, attr_name, None)
                if value is not None:
                    extra[extra_key] = value
            if extra:
                kwargs["extra"] = extra
        return msg, kwargs
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
class JSONFormatter(logging.Formatter):
    """JSON formatter for structured logging to files.

    This formatter converts log records to JSON format, including all
    contextual information and metadata for structured analysis.
    """

    # Attributes present on every ``logging.LogRecord``. The original
    # implementation copied *all* record attributes into the JSON entry,
    # which duplicated standard metadata (``msg``, ``args``, ``levelname``,
    # thread/process ids, ...) and could crash ``json.dumps`` when
    # ``exc_info`` held a live traceback object. Only user-supplied extras
    # (fields beyond this set) are exported.
    _STANDARD_ATTRS = frozenset(
        vars(logging.LogRecord("", 0, "", 0, "", None, None)).keys()
    )

    def format(self, record: logging.LogRecord) -> str:
        """Format the log record as JSON.

        Parameters
        ----------
        record : logging.LogRecord
            The log record to format

        Returns
        -------
        str
            JSON-formatted log message
        """
        log_entry = {
            "timestamp": datetime.fromtimestamp(record.created).isoformat(),
            "level": record.levelname,
            "logger": record.name,
            "message": record.getMessage(),
            "module": record.module,
            "function": record.funcName,
            "line": record.lineno,
        }

        # Add contextual information if available
        if hasattr(record, "context"):
            log_entry["context"] = record.context

        if hasattr(record, "caller_object"):
            log_entry["caller_object"] = record.caller_object

        if hasattr(record, "iteration"):
            log_entry["iteration"] = record.iteration

        if hasattr(record, "phase"):
            log_entry["phase"] = record.phase

        # Add any user-supplied extra fields (skip standard record
        # attributes and keys already emitted above).
        for key, value in record.__dict__.items():
            if key not in self._STANDARD_ATTRS and key not in log_entry:
                log_entry[key] = value

        # ``default=str`` stringifies non-JSON-serializable extras rather
        # than letting the formatter raise mid-logging.
        return json.dumps(log_entry, default=str)
|
| 149 |
+
|
| 150 |
+
|
| 151 |
+
def _get_context_stack():
    """Get the context stack for the current thread, creating it lazily."""
    try:
        return _context_storage.context_stack
    except AttributeError:
        # First access on this thread: install an empty stack.
        stack = []
        _context_storage.context_stack = stack
        return stack
|
| 156 |
+
|
| 157 |
+
|
| 158 |
+
class ContextFormatter(logging.Formatter):
    """Standard formatter that includes active learning context information with colors."""

    def format(self, record):
        # Gather whatever contextual fields the injecting filter attached.
        parts = []
        if getattr(record, "caller_object", None):
            parts.append(f"obj:{record.caller_object}")
        if getattr(record, "run_id", None):
            parts.append(f"run:{record.run_id}")
        if getattr(record, "iteration", None) is not None:
            parts.append(f"iter:{record.iteration}")
        if getattr(record, "phase", None):
            parts.append(f"phase:{record.phase}")
        extra_context = getattr(record, "context", None)
        if extra_context:
            for key, value in extra_context.items():
                parts.append(f"{key}:{value}")

        context_str = f"[{', '.join(parts)}]" if parts else ""

        # Delegate the base message to the standard formatter.
        base_msg = super().format(record)

        # Colorize by severity when termcolor is installed
        # (``colored`` is None otherwise).
        if colored is not None:
            if record.levelno >= logging.ERROR:
                base_msg = colored(base_msg, "red")
            elif record.levelno >= logging.WARNING:
                base_msg = colored(base_msg, "yellow")
            elif record.levelno >= logging.INFO:
                base_msg = colored(base_msg, "white")
            else:  # DEBUG
                base_msg = colored(base_msg, "cyan")

        # Append the (optionally colored) context suffix.
        if context_str:
            if colored is not None:
                context_str = colored(context_str, "blue")
            base_msg += f" {context_str}"

        return base_msg
|
| 200 |
+
|
| 201 |
+
|
| 202 |
+
class ContextInjectingFilter(logging.Filter):
    """Filter that injects contextual information into log records."""

    def filter(self, record):
        """Attach the innermost ``log_context`` frame (if any) to ``record``.

        Always returns True: the filter never suppresses records, it only
        enriches them with context fields for the formatters.
        """
        # Add context information from thread-local storage
        context_stack = _get_context_stack()
        if context_stack:
            current_context = context_stack[-1]
            # ``.get`` is used for every key (the original mixed direct
            # indexing with ``.get``), so a partially-populated context
            # dict cannot raise ``KeyError`` during logging.
            if current_context.get("caller_object"):
                record.caller_object = current_context["caller_object"]
            if current_context.get("iteration") is not None:
                record.iteration = current_context["iteration"]
            if current_context.get("phase"):
                record.phase = current_context["phase"]
            if current_context.get("context"):
                record.context = current_context["context"]
        return True
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def setup_active_learning_logger(
    name: str,
    run_id: str,
    log_dir: str | Path = Path("active_learning_logs"),
    level: int = logging.INFO,
) -> logging.Logger:
    """Set up a logger with active learning-specific formatting and handlers.

    Parameters
    ----------
    name : str
        Logger name
    run_id : str
        Unique identifier for this run, used in log filename
    log_dir : str | Path, optional
        Directory to store log files, by default ``"active_learning_logs"``
    level : int, optional
        Logging level, by default logging.INFO

    Returns
    -------
    logging.Logger
        Configured standard Python logger

    Example
    -------
    >>> logger = setup_active_learning_logger("experiment", "run_001")
    >>> logger.info("Starting experiment")
    >>> with log_context(caller_object="Trainer", iteration=5):
    ...     logger.info("Training step")
    """
    # Get standard logger
    logger = logging.getLogger(name)
    logger.setLevel(level)

    # Clear any existing handlers to avoid duplicates
    logger.handlers.clear()

    # Disable propagation to prevent duplicate messages from parent loggers
    logger.propagate = False

    # Create log directory if it doesn't exist. ``Path`` accepts both str
    # and Path inputs (Path(Path(...)) is a no-op), so no isinstance check
    # is needed.
    log_dir_path = Path(log_dir)
    log_dir_path.mkdir(parents=True, exist_ok=True)

    # Set up console handler with standard formatting; the handler is left
    # at DEBUG so the logger's own level is the single filtering knob.
    console_handler = logging.StreamHandler()
    console_handler.setLevel(logging.DEBUG)
    console_handler.setFormatter(
        ContextFormatter("%(asctime)s - %(name)s - %(levelname)s - %(message)s")
    )
    console_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(console_handler)

    # Set up file handler with JSON formatting
    log_file = log_dir_path / f"{run_id}.log"
    file_handler = logging.FileHandler(log_file, mode="w")
    file_handler.setLevel(logging.DEBUG)
    file_handler.setFormatter(JSONFormatter())
    file_handler.addFilter(ContextInjectingFilter())
    logger.addHandler(file_handler)

    return logger
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
@contextmanager
def log_context(
    caller_object: str | None = None,
    iteration: int | None = None,
    phase: str | None = None,
    **kwargs: Any,
):
    """Context manager for adding contextual information to log messages.

    Parameters
    ----------
    caller_object : str, optional
        Name or identifier of the object making the log call
    iteration : int, optional
        Current iteration counter
    phase : str, optional
        Current phase of the active learning process
    **kwargs : Any
        Additional contextual key-value pairs

    Example
    -------
    >>> from logging import getLogger
    >>> from physicsnemo.active_learning.logger import log_context
    >>> logger = getLogger("my_logger")
    >>> with log_context(caller_object="Trainer", iteration=5, phase="training", epoch=2):
    ...     logger.info("Processing batch")
    """
    # Push one frame describing this scope onto the thread-local stack;
    # the injecting filter reads the innermost frame at log time.
    frame = {
        "caller_object": caller_object,
        "iteration": iteration,
        "phase": phase,
        "context": kwargs,
    }
    stack = _get_context_stack()
    stack.append(frame)
    try:
        yield
    finally:
        # Always unwind, even if the body raised.
        stack.pop()
|
physics_mcp/source/physicsnemo/active_learning/loop.py
ADDED
|
@@ -0,0 +1,534 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from __future__ import annotations
|
| 18 |
+
|
| 19 |
+
import inspect
|
| 20 |
+
from pathlib import Path
|
| 21 |
+
from typing import Any
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
from torch.optim import Optimizer
|
| 25 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 26 |
+
from torch.utils.data import DataLoader
|
| 27 |
+
from tqdm import tqdm
|
| 28 |
+
|
| 29 |
+
from physicsnemo import Module
|
| 30 |
+
from physicsnemo.active_learning import protocols as p
|
| 31 |
+
from physicsnemo.distributed import DistributedManager
|
| 32 |
+
from physicsnemo.launch.logging import LaunchLogger
|
| 33 |
+
from physicsnemo.utils.capture import StaticCaptureEvaluateNoGrad, StaticCaptureTraining
|
| 34 |
+
|
| 35 |
+
__all__ = ["DefaultTrainingLoop"]
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
def _recursive_data_device_cast(
|
| 39 |
+
data: Any,
|
| 40 |
+
device: torch.device | str | None = None,
|
| 41 |
+
dtype: torch.dtype | None = None,
|
| 42 |
+
**kwargs: Any,
|
| 43 |
+
) -> Any:
|
| 44 |
+
"""
|
| 45 |
+
Recursively moves/cast input data to a specified device and dtype.
|
| 46 |
+
|
| 47 |
+
For iterable objects, we recurse through the elements depending on
|
| 48 |
+
the type of iterable until we reach an object that either has a ``to``
|
| 49 |
+
method that can be called, or just returns the data unchanged.
|
| 50 |
+
|
| 51 |
+
Parameters
|
| 52 |
+
----------
|
| 53 |
+
data: Any
|
| 54 |
+
The data to move to the device.
|
| 55 |
+
device: torch.device | str | None = None
|
| 56 |
+
The device to move the data to.
|
| 57 |
+
dtype: torch.dtype | None = None
|
| 58 |
+
The dtype to move the data to.
|
| 59 |
+
kwargs: Any
|
| 60 |
+
Additional keyword arguments to pass to the `to` method.
|
| 61 |
+
By default, `non_blocking` is set to `True` to allow
|
| 62 |
+
asynchronous data transfers.
|
| 63 |
+
|
| 64 |
+
Returns
|
| 65 |
+
-------
|
| 66 |
+
Any
|
| 67 |
+
The data moved to the device.
|
| 68 |
+
"""
|
| 69 |
+
kwargs.setdefault("non_blocking", True)
|
| 70 |
+
if hasattr(data, "to"):
|
| 71 |
+
# if there is a `to` method, then we can just call it
|
| 72 |
+
return data.to(device=device, dtype=dtype, **kwargs)
|
| 73 |
+
elif isinstance(data, dict):
|
| 74 |
+
return {
|
| 75 |
+
k: _recursive_data_device_cast(v, device, dtype) for k, v in data.items()
|
| 76 |
+
}
|
| 77 |
+
elif isinstance(data, list):
|
| 78 |
+
return [_recursive_data_device_cast(v, device, dtype) for v in data]
|
| 79 |
+
elif isinstance(data, tuple):
|
| 80 |
+
return tuple(_recursive_data_device_cast(v, device, dtype) for v in data)
|
| 81 |
+
else:
|
| 82 |
+
return data
|
| 83 |
+
|
| 84 |
+
|
| 85 |
+
class DefaultTrainingLoop(p.TrainingLoop):
|
| 86 |
+
    def __new__(cls, *args: Any, **kwargs: Any) -> DefaultTrainingLoop:
        """
        Wrapper for instantiating DefaultTrainingLoop.

        This method captures arguments used to instantiate the loop
        and stores them in the `_args` attribute for serialization.
        This follows the same pattern as `ActiveLearningProtocol.__new__`.

        Parameters
        ----------
        args: Any
            Arguments to pass to the loop's constructor.
        kwargs: Any
            Keyword arguments to pass to the loop's constructor.

        Returns
        -------
        DefaultTrainingLoop
            A new instance with an `_args` attribute for serialization.
        """
        out = super().__new__(cls)

        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)

        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        # apply_defaults fills in every unbound parameter, so `arguments`
        # below contains an entry for each signature parameter, in
        # signature order.
        bound_args.apply_defaults()

        # Get args and kwargs (excluding self and unroll kwargs)
        instantiate_args = {}
        # NOTE(review): this zip relies on `bound_args.arguments` iterating
        # in the same order as `sig.parameters` — true after
        # apply_defaults(), since both follow declaration order.
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            # Skip self
            if k == "self":
                continue

            # Add args and kwargs to instantiate_args
            if param.kind == param.VAR_KEYWORD:
                # A **kwargs parameter binds to a dict; flatten its entries
                # so `__args__` stays a single flat mapping.
                instantiate_args.update(v)
            else:
                # Special handling for device: convert torch.device to string
                # (torch.device itself is not JSON/config serializable).
                if k == "device" and isinstance(v, torch.device):
                    instantiate_args[k] = str(v)
                # Special handling for dtype: convert to string representation
                elif k == "dtype" and isinstance(v, torch.dtype):
                    instantiate_args[k] = str(v)
                else:
                    instantiate_args[k] = v

        # Store args needed for instantiation: class identity plus the
        # (serializable) constructor arguments.
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out
|
| 144 |
+
|
| 145 |
+
def __init__(
|
| 146 |
+
self,
|
| 147 |
+
train_step_fn: p.TrainingProtocol | None = None,
|
| 148 |
+
validate_step_fn: p.ValidationProtocol | None = None,
|
| 149 |
+
enable_static_capture: bool = True,
|
| 150 |
+
use_progress_bars: bool = True,
|
| 151 |
+
device: str | torch.device | None = None,
|
| 152 |
+
dtype: torch.dtype | None = None,
|
| 153 |
+
checkpoint_frequency: int = 0,
|
| 154 |
+
**capture_kwargs: Any,
|
| 155 |
+
) -> None:
|
| 156 |
+
"""
|
| 157 |
+
Initializes the default training loop.
|
| 158 |
+
|
| 159 |
+
The general usage of this loop is to
|
| 160 |
+
|
| 161 |
+
TODO: add support for early stopping
|
| 162 |
+
|
| 163 |
+
Parameters
|
| 164 |
+
----------
|
| 165 |
+
train_step_fn: TrainingProtocol | None = None
|
| 166 |
+
A callable that implements the logic for performing a single
|
| 167 |
+
training step. See ``protocols.TrainingProtocol`` for the expected
|
| 168 |
+
interface, but ultimately the function should return a scalar loss
|
| 169 |
+
value that has a ``backward`` method.
|
| 170 |
+
validate_step_fn: ValidationProtocol | None = None
|
| 171 |
+
A callable that implements the logic for performing a single
|
| 172 |
+
validation step. See ``protocols.ValidationProtocol`` for the expected
|
| 173 |
+
interface, but in contrast to ``train_step_fn`` this function should
|
| 174 |
+
not return anything.
|
| 175 |
+
enable_static_capture: bool = True
|
| 176 |
+
Whether to enable static capture for the training and validation steps.
|
| 177 |
+
use_progress_bars: bool = True
|
| 178 |
+
Whether to show ``tqdm`` progress bars to display epoch and step progress.
|
| 179 |
+
device: str | torch.device | None = None
|
| 180 |
+
The device used for performing the loop. If not provided, then the device
|
| 181 |
+
will default to the model's device at runtime.
|
| 182 |
+
dtype: torch.dtype | None = None
|
| 183 |
+
The dtype used for performing the loop. If not provided, then the dtype
|
| 184 |
+
will default to ``torch.get_default_dtype()``.
|
| 185 |
+
checkpoint_frequency: int = 0
|
| 186 |
+
How often to save checkpoints during training (every N epochs).
|
| 187 |
+
If 0, no checkpoints are saved during training. Set via Driver before
|
| 188 |
+
training execution.
|
| 189 |
+
capture_kwargs: Any
|
| 190 |
+
Additional keyword arguments to pass to the static capture decorators.
|
| 191 |
+
"""
|
| 192 |
+
self.train_step_fn = train_step_fn
|
| 193 |
+
self.validate_step_fn = validate_step_fn
|
| 194 |
+
self.enable_static_capture = enable_static_capture
|
| 195 |
+
if isinstance(device, str):
|
| 196 |
+
device = torch.device(device)
|
| 197 |
+
# check to see if we can rely on DistributedManager
|
| 198 |
+
if device is None and DistributedManager.is_initialized():
|
| 199 |
+
device = DistributedManager.device
|
| 200 |
+
self.device = device
|
| 201 |
+
if dtype is None:
|
| 202 |
+
dtype = torch.get_default_dtype()
|
| 203 |
+
self.dtype = dtype
|
| 204 |
+
self.capture_kwargs = capture_kwargs
|
| 205 |
+
self.use_progress_bars = use_progress_bars
|
| 206 |
+
self.capture_functions = {}
|
| 207 |
+
self.checkpoint_frequency = checkpoint_frequency
|
| 208 |
+
self.checkpoint_base_dir: Path | None = None
|
| 209 |
+
|
| 210 |
+
def save_training_checkpoint(
|
| 211 |
+
self,
|
| 212 |
+
checkpoint_dir: Path,
|
| 213 |
+
model: Module | p.LearnerProtocol,
|
| 214 |
+
optimizer: Optimizer,
|
| 215 |
+
lr_scheduler: _LRScheduler | None = None,
|
| 216 |
+
training_epoch: int | None = None,
|
| 217 |
+
) -> None:
|
| 218 |
+
"""
|
| 219 |
+
Save training state to checkpoint directory.
|
| 220 |
+
|
| 221 |
+
Model weights are saved separately. Optimizer, scheduler, and epoch
|
| 222 |
+
metadata are combined into a single training_state.pt file.
|
| 223 |
+
|
| 224 |
+
Parameters
|
| 225 |
+
----------
|
| 226 |
+
checkpoint_dir: Path
|
| 227 |
+
Directory to save checkpoint files.
|
| 228 |
+
model: Module | p.LearnerProtocol
|
| 229 |
+
Model to save weights for.
|
| 230 |
+
optimizer: Optimizer
|
| 231 |
+
Optimizer to save state from.
|
| 232 |
+
lr_scheduler: _LRScheduler | None
|
| 233 |
+
Optional LR scheduler to save state from.
|
| 234 |
+
training_epoch: int | None
|
| 235 |
+
Current training epoch for metadata.
|
| 236 |
+
"""
|
| 237 |
+
checkpoint_dir.mkdir(parents=True, exist_ok=True)
|
| 238 |
+
|
| 239 |
+
# Save model weights separately
|
| 240 |
+
if isinstance(model, Module):
|
| 241 |
+
model_path = checkpoint_dir / "model.mdlus"
|
| 242 |
+
model.save(str(model_path))
|
| 243 |
+
else:
|
| 244 |
+
model_path = checkpoint_dir / "model_state.pt"
|
| 245 |
+
torch.save(model.state_dict(), model_path)
|
| 246 |
+
|
| 247 |
+
# Combine optimizer, scheduler, and epoch metadata into single file
|
| 248 |
+
training_state = {
|
| 249 |
+
"optimizer_state": optimizer.state_dict(),
|
| 250 |
+
"lr_scheduler_state": lr_scheduler.state_dict() if lr_scheduler else None,
|
| 251 |
+
"training_epoch": training_epoch,
|
| 252 |
+
}
|
| 253 |
+
training_state_path = checkpoint_dir / "training_state.pt"
|
| 254 |
+
torch.save(training_state, training_state_path)
|
| 255 |
+
|
| 256 |
+
@staticmethod
def load_training_checkpoint(
    checkpoint_dir: Path,
    model: Module | p.LearnerProtocol,
    optimizer: Optimizer,
    lr_scheduler: _LRScheduler | None = None,
) -> int | None:
    """
    Restore training state from ``checkpoint_dir``.

    Model weights are read from their own file; optimizer, scheduler, and
    epoch metadata come from the combined ``training_state.pt`` bundle
    written by ``save_training_checkpoint``. Missing files are silently
    skipped so partial checkpoints load what they can.

    Parameters
    ----------
    checkpoint_dir: Path
        Directory containing the checkpoint files.
    model: Module | p.LearnerProtocol
        Model to load weights into.
    optimizer: Optimizer
        Optimizer to load state into.
    lr_scheduler: _LRScheduler | None
        Optional LR scheduler to load state into.

    Returns
    -------
    int | None
        The training epoch recorded in the bundle, or None when the bundle
        is absent or carries no epoch.
    """
    # Pick the weights file matching how `save_training_checkpoint`
    # serialized the model (native .mdlus vs. plain state dict).
    if isinstance(model, Module):
        weights_file = checkpoint_dir / "model.mdlus"
        if weights_file.exists():
            model.load(str(weights_file))
    else:
        weights_file = checkpoint_dir / "model_state.pt"
        if weights_file.exists():
            model.load_state_dict(torch.load(weights_file, map_location="cpu"))

    bundle_path = checkpoint_dir / "training_state.pt"
    if not bundle_path.exists():
        return None

    bundle = torch.load(bundle_path, map_location="cpu")
    if "optimizer_state" in bundle:
        optimizer.load_state_dict(bundle["optimizer_state"])
    # Only restore the scheduler when both a scheduler and a saved state exist.
    if lr_scheduler and bundle.get("lr_scheduler_state"):
        lr_scheduler.load_state_dict(bundle["lr_scheduler_state"])
    return bundle.get("training_epoch")
+
|
| 314 |
+
@property
def amp_type(self) -> torch.dtype:
    """Dtype for autocast: the configured half precision if set, else float16."""
    half_precision_types = (torch.float16, torch.bfloat16)
    return self.dtype if self.dtype in half_precision_types else torch.float16
+
|
| 321 |
+
def _create_capture_functions(
    self,
    model: Module | p.LearnerProtocol,
    optimizer: Optimizer,
    train_step_fn: p.TrainingProtocol | None = None,
    validate_step_fn: p.ValidationProtocol | None = None,
) -> tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]:
    """
    Attempt to create static capture functions based off training and validation
    functions.

    This uses the Python object IDs to uniquely identify functions, and adds the
    decorated functions to an internal `capture_functions` dictionary. If the
    decorated functions already exist, then this function will be no-op.

    Parameters
    ----------
    model: Module | p.LearnerProtocol
        The model to train.
    optimizer: Optimizer
        The optimizer to use for training.
    train_step_fn: p.TrainingProtocol | None = None
        The training function to use for training.
    validate_step_fn: p.ValidationProtocol | None = None
        The validation function to use for validation.

    Returns
    -------
    tuple[p.TrainingProtocol | None, p.ValidationProtocol | None]
        The training and validation functions with static capture applied.
    """
    # Fall back to the function registered on the loop when none is passed.
    if not train_step_fn:
        train_step_fn = self.train_step_fn
    # Cache key is the id() of the *undecorated* function, so repeat calls
    # with the same function object reuse the existing capture wrapper.
    # NOTE(review): id() values can be reused after an object is garbage
    # collected — assumes callers keep the step functions alive; confirm.
    train_func_id = id(train_step_fn)
    if train_func_id not in self.capture_functions:
        try:
            train_step_fn = StaticCaptureTraining(
                model=model,
                optim=optimizer,
                amp_type=self.amp_type,
                **self.capture_kwargs,
            )(train_step_fn)
            self.capture_functions[train_func_id] = train_step_fn
        except Exception as e:
            # Re-raise with context so the user knows which step function
            # failed to wrap; the original exception is chained.
            raise RuntimeError(
                "Failed to create static capture for `train_step_fn`. "
            ) from e
    else:
        # Cache hit: reuse the previously decorated wrapper.
        train_step_fn = self.capture_functions[train_func_id]
    if not validate_step_fn:
        validate_step_fn = self.validate_step_fn
    # Validation is optional: only wrap when a validation function exists.
    if validate_step_fn:
        val_func_id = id(validate_step_fn)
        if val_func_id not in self.capture_functions:
            try:
                # No-grad evaluation capture; no optimizer is involved.
                validate_step_fn = StaticCaptureEvaluateNoGrad(
                    model=model, amp_type=self.amp_type, **self.capture_kwargs
                )(validate_step_fn)
                self.capture_functions[val_func_id] = validate_step_fn
            except Exception as e:
                raise RuntimeError(
                    "Failed to create static capture for `validate_step_fn`. "
                ) from e
        else:
            validate_step_fn = self.capture_functions[val_func_id]
    return train_step_fn, validate_step_fn
+
|
| 388 |
+
def __call__(
    self,
    model: Module | p.LearnerProtocol,
    optimizer: Optimizer,
    train_dataloader: DataLoader,
    max_epochs: int,
    validation_dataloader: DataLoader | None = None,
    train_step_fn: p.TrainingProtocol | None = None,
    validate_step_fn: p.ValidationProtocol | None = None,
    lr_scheduler: _LRScheduler | None = None,
    device: str | torch.device | None = None,
    dtype: torch.dtype | None = None,
    *args: Any,
    **kwargs: Any,
) -> None:
    """
    Performs ``max_epochs`` epochs of training and optionally validation.

    Some of the arguments, such as ``train_step_fn`` and ``validate_step_fn``,
    are optional only if the ``model`` implements the ``p.LearnerProtocol``.
    If they are passed, however, they will take precedence over the methods
    originally provided to the constructor method.

    The bare minimum required arguments for this loop to work are:
    1. A model to train
    2. An optimizer to step
    3. A training dataloader to iterate over
    4. The maximum number of epochs to train for

    If validation is required, then both ``validation_dataloader`` and
    ``validate_step_fn`` must be specified.

    Parameters
    ----------
    model: Module | p.LearnerProtocol
        The model to train.
    optimizer: torch.optim.Optimizer
        The optimizer to use for training.
    train_dataloader: DataLoader
        The dataloader to use for training.
    max_epochs: int
        The number of epochs to train for.
    validation_dataloader: DataLoader | None
        The dataloader to use for validation. If not provided, then validation
        will not be performed.
    train_step_fn: p.TrainingProtocol | None = None
        The training function to use for training. If passed, it will take
        precedence over the method provided to the constructor method.
    validate_step_fn: p.ValidationProtocol | None = None
        The validation function to use for validation.
    lr_scheduler: torch.optim.lr_scheduler._LRScheduler | None = None
        The learning rate scheduler to use for training.
    device: str | torch.device | None = None
        The device used for performing the loop. If provided, it will
        override the device specified in the constructor. If both values
        are not provided, then we default to PyTorch's default device.
    dtype: torch.dtype | None = None
        The dtype used for performing the loop. If provided, it will
        override the dtype specified in the constructor. If both values
        are not provided, then we default to PyTorch's default dtype.
    args: Any
        Additional arguments to pass the training and validation
        step functions.
    kwargs: Any
        Additional keyword arguments to pass the training and validation
        step functions.

    Raises
    ------
    RuntimeError
        If no training step function was provided either here or at
        construction time.
    """
    if not train_step_fn and not self.train_step_fn:
        raise RuntimeError(
            """
            No training step function provided.
            Either provide a `train_step_fn` to this constructor, or
            provide a `train_step_fn` to the `__call__` method.
            """
        )
    # Resolve precedence: explicit argument > constructor value > torch
    # default. Fix: the constructor's `self.device`/`self.dtype` were
    # previously never applied here (only the torch defaults were used,
    # and only when *both* values were unset), contrary to the docstring.
    if not device:
        device = self.device if self.device else torch.get_default_device()
    if not dtype:
        dtype = self.dtype if self.dtype else torch.get_default_dtype()
    # if a device is specified, move the model
    if device and device != model.device:
        # not 100% sure this will trigger issues with the optimizer
        # but allows a potentially different device to be used
        model = model.to(device)
    if self.enable_static_capture:
        # if static capture is enabled, we check for a cache hit based on
        # the incoming function IDs. If we miss, we then create new wrappers.
        train_step_fn, validate_step_fn = self._create_capture_functions(
            model, optimizer, train_step_fn, validate_step_fn
        )
    epoch_iter = range(1, max_epochs + 1)
    if self.use_progress_bars:
        epoch_iter = tqdm(epoch_iter, desc="Epoch", leave=False, position=0)
    ########### EPOCH LOOP ###########
    for epoch in epoch_iter:
        model.train()
        train_iter = iter(train_dataloader)
        if self.use_progress_bars:
            train_iter = tqdm(
                train_iter, desc="Training step", leave=False, unit="batch"
            )
        ########### TRAINING STEP LOOP ###########
        with LaunchLogger(
            "train", epoch=epoch, num_mini_batch=len(train_dataloader)
        ) as log:
            for batch in train_iter:
                # move/cast nested batch structures onto the loop's
                # device and dtype before the forward pass
                batch = _recursive_data_device_cast(
                    batch, device=device, dtype=dtype
                )
                model.zero_grad(set_to_none=True)
                loss = train_step_fn(model, batch, *args, **kwargs)
                log.log_minibatch({"train_loss": loss.detach().item()})
                # normally, static capture will call backward because of AMP
                if not self.enable_static_capture:
                    loss.backward()
                    optimizer.step()
        # advance the LR schedule once per epoch
        if lr_scheduler:
            lr_scheduler.step()
        ########### VALIDATION STEP LOOP ###########
        if validate_step_fn and validation_dataloader:
            model.eval()
            val_iter = iter(validation_dataloader)
            if self.use_progress_bars:
                val_iter = tqdm(
                    val_iter, desc="Validation step", leave=False, unit="batch"
                )
            with LaunchLogger(
                "validation", epoch=epoch, num_mini_batch=len(validation_dataloader)
            ) as log:
                for batch in val_iter:
                    batch = _recursive_data_device_cast(
                        batch, device=device, dtype=dtype
                    )
                    validate_step_fn(model, batch, *args, **kwargs)

        ########### CHECKPOINT SAVE ###########
        # Save training state at specified frequency
        if self.checkpoint_base_dir and self.checkpoint_frequency > 0:
            if epoch % self.checkpoint_frequency == 0:
                epoch_checkpoint_dir = self.checkpoint_base_dir / f"epoch_{epoch}"
                self.save_training_checkpoint(
                    checkpoint_dir=epoch_checkpoint_dir,
                    model=model,
                    optimizer=optimizer,
                    lr_scheduler=lr_scheduler,
                    training_epoch=epoch,
                )
|
physics_mcp/source/physicsnemo/active_learning/protocols.py
ADDED
|
@@ -0,0 +1,1394 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
This module contains base classes for active learning protocols.
|
| 19 |
+
|
| 20 |
+
These are protocols intended to be abstract, and importing these
|
| 21 |
+
classes specifically is intended to either be subclassed, or for
|
| 22 |
+
type annotations.
|
| 23 |
+
|
| 24 |
+
Protocol Architecture
|
| 25 |
+
---------------------
|
| 26 |
+
Python ``Protocol``s are used for structural typing: essentially, they are used to
|
| 27 |
+
describe an expected interface in a way that is helpful for static type checkers
|
| 28 |
+
to make sure concrete implementations provide everything that is needed for a workflow
|
| 29 |
+
to function. ``Protocol``s are not actually enforced at runtime, and inheritance is not
|
| 30 |
+
required for them to function: as long as the implementation provides the expected
|
| 31 |
+
attributes and methods, they will be compatible with the protocol.
|
| 32 |
+
|
| 33 |
+
The active learning framework is built around several key protocol abstractions
|
| 34 |
+
that work together to orchestrate the active learning workflow:
|
| 35 |
+
|
| 36 |
+
**Core Infrastructure Protocols:**
|
| 37 |
+
- `AbstractQueue[T]` - Generic queue protocol for passing data between components
|
| 38 |
+
- `DataPool[T]` - Protocol for data reservoirs that support appending and sampling
|
| 39 |
+
- `ActiveLearningProtocol` - Base protocol providing common interface for all AL strategies
|
| 40 |
+
|
| 41 |
+
**Strategy Protocols (inherit from ActiveLearningProtocol):**
|
| 42 |
+
- `QueryStrategy` - Defines how to select data points for labeling
|
| 43 |
+
- `LabelStrategy` - Defines processes for adding ground truth labels to unlabeled data
|
| 44 |
+
- `MetrologyStrategy` - Defines procedures that assess model improvements beyond validation metrics
|
| 45 |
+
|
| 46 |
+
**Model Interface Protocols:**
|
| 47 |
+
- `TrainingProtocol` - Interface for training step functions
|
| 48 |
+
- `ValidationProtocol` - Interface for validation step functions
|
| 49 |
+
- `InferenceProtocol` - Interface for inference step functions
|
| 50 |
+
- `TrainingLoop` - Interface for complete training loop implementations
|
| 51 |
+
- `LearnerProtocol` - Comprehensive interface for learner modules (combines training/validation/inference)
|
| 52 |
+
|
| 53 |
+
**Orchestration Protocol:**
|
| 54 |
+
- `DriverProtocol` - Main orchestrator that coordinates all components in the active learning loop
|
| 55 |
+
|
| 56 |
+
Protocol Relationships
|
| 57 |
+
----------------------
|
| 58 |
+
|
| 59 |
+
```mermaid
|
| 60 |
+
graph TB
|
| 61 |
+
subgraph "Core Infrastructure"
|
| 62 |
+
AQ[AbstractQueue<T>]
|
| 63 |
+
DP[DataPool<T>]
|
| 64 |
+
ALP[ActiveLearningProtocol]
|
| 65 |
+
end
|
| 66 |
+
|
| 67 |
+
subgraph "Strategy Layer"
|
| 68 |
+
QS[QueryStrategy]
|
| 69 |
+
LS[LabelStrategy]
|
| 70 |
+
MS[MetrologyStrategy]
|
| 71 |
+
end
|
| 72 |
+
|
| 73 |
+
subgraph "Model Interface Layer"
|
| 74 |
+
TP[TrainingProtocol]
|
| 75 |
+
VP[ValidationProtocol]
|
| 76 |
+
IP[InferenceProtocol]
|
| 77 |
+
TL[TrainingLoop]
|
| 78 |
+
LP[LearnerProtocol]
|
| 79 |
+
end
|
| 80 |
+
|
| 81 |
+
subgraph "Orchestration Layer"
|
| 82 |
+
Driver[DriverProtocol]
|
| 83 |
+
end
|
| 84 |
+
|
| 85 |
+
%% Inheritance relationships (thick blue arrows)
|
| 86 |
+
ALP ==>|inherits| QS
|
| 87 |
+
ALP ==>|inherits| LS
|
| 88 |
+
ALP ==>|inherits| MS
|
| 89 |
+
|
| 90 |
+
%% Composition relationships (dashed green arrows)
|
| 91 |
+
Driver -.->|uses| LP
|
| 92 |
+
Driver -.->|manages| QS
|
| 93 |
+
Driver -.->|manages| LS
|
| 94 |
+
Driver -.->|manages| MS
|
| 95 |
+
Driver -.->|contains| DP
|
| 96 |
+
Driver -.->|contains| AQ
|
| 97 |
+
|
| 98 |
+
%% Protocol usage relationships (dotted purple arrows)
|
| 99 |
+
TL -.->|can use| TP
|
| 100 |
+
TL -.->|can use| VP
|
| 101 |
+
TL -.->|can use| LP
|
| 102 |
+
LP -.->|implements| TP
|
| 103 |
+
LP -.->|implements| VP
|
| 104 |
+
LP -.->|implements| IP
|
| 105 |
+
|
| 106 |
+
%% Data flow relationships (solid red arrows)
|
| 107 |
+
QS -->|enqueues to| AQ
|
| 108 |
+
AQ -->|consumed by| LS
|
| 109 |
+
LS -->|enqueues to| AQ
|
| 110 |
+
|
| 111 |
+
%% Styling for different relationship types
|
| 112 |
+
linkStyle 0 stroke:#1976d2,stroke-width:4px
|
| 113 |
+
linkStyle 1 stroke:#1976d2,stroke-width:4px
|
| 114 |
+
linkStyle 2 stroke:#1976d2,stroke-width:4px
|
| 115 |
+
linkStyle 3 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 116 |
+
linkStyle 4 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 117 |
+
linkStyle 5 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 118 |
+
linkStyle 6 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 119 |
+
linkStyle 7 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 120 |
+
linkStyle 8 stroke:#388e3c,stroke-width:2px,stroke-dasharray: 5 5
|
| 121 |
+
linkStyle 9 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 122 |
+
linkStyle 10 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 123 |
+
linkStyle 11 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 124 |
+
linkStyle 12 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 125 |
+
linkStyle 13 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 126 |
+
linkStyle 14 stroke:#7b1fa2,stroke-width:2px,stroke-dasharray: 2 2
|
| 127 |
+
linkStyle 15 stroke:#d32f2f,stroke-width:3px
|
| 128 |
+
linkStyle 16 stroke:#d32f2f,stroke-width:3px
|
| 129 |
+
linkStyle 17 stroke:#d32f2f,stroke-width:3px
|
| 130 |
+
|
| 131 |
+
%% Node styling
|
| 132 |
+
classDef coreInfra fill:#e3f2fd,stroke:#1976d2,stroke-width:2px
|
| 133 |
+
classDef strategy fill:#f3e5f5,stroke:#7b1fa2,stroke-width:2px
|
| 134 |
+
classDef modelInterface fill:#e8f5e8,stroke:#388e3c,stroke-width:2px
|
| 135 |
+
classDef orchestration fill:#fff3e0,stroke:#f57c00,stroke-width:3px
|
| 136 |
+
|
| 137 |
+
class AQ,DP,ALP coreInfra
|
| 138 |
+
class QS,LS,MS strategy
|
| 139 |
+
class TP,VP,IP,TL,LP modelInterface
|
| 140 |
+
class Driver orchestration
|
| 141 |
+
```
|
| 142 |
+
|
| 143 |
+
**Relationship Legend:**
|
| 144 |
+
- **Blue thick arrows (==>)**: Inheritance relationships (subclass extends parent)
|
| 145 |
+
- **Green dashed arrows (-.->)**: Composition relationships (object contains/manages other objects)
|
| 146 |
+
- **Purple dotted arrows (-.->)**: Protocol usage relationships (can use or implements interface)
|
| 147 |
+
- **Red solid arrows (-->)**: Data flow relationships (data moves between components)
|
| 148 |
+
|
| 149 |
+
Active Learning Workflow
|
| 150 |
+
------------------------
|
| 151 |
+
|
| 152 |
+
The typical active learning workflow orchestrated by `DriverProtocol` follows this sequence:
|
| 153 |
+
|
| 154 |
+
1. **Training Phase**: Use `LearnerProtocol` or `TrainingLoop` to train the model on `training_pool`
|
| 155 |
+
2. **Metrology Phase** (optional): Apply `MetrologyStrategy` instances to assess model performance
|
| 156 |
+
3. **Query Phase**: Apply `QueryStrategy` instances to select samples from `unlabeled_pool` → `query_queue`
|
| 157 |
+
4. **Labeling Phase** (optional): Apply `LabelStrategy` instances to label queued samples → `label_queue`
|
| 158 |
+
5. **Data Integration**: Move labeled data from `label_queue` to `training_pool`
|
| 159 |
+
|
| 160 |
+
Type Parameters
|
| 161 |
+
---------------
|
| 162 |
+
- `T`: Data structure containing both inputs and ground truth labels
|
| 163 |
+
- `S`: Data structure containing only inputs (no ground truth labels)
|
| 164 |
+
"""
|
| 165 |
+
|
| 166 |
+
from __future__ import annotations
|
| 167 |
+
|
| 168 |
+
import inspect
|
| 169 |
+
import logging
|
| 170 |
+
from enum import StrEnum
|
| 171 |
+
from logging import Logger
|
| 172 |
+
from pathlib import Path
|
| 173 |
+
from typing import Any, Iterator, Protocol, TypeVar
|
| 174 |
+
|
| 175 |
+
import torch
|
| 176 |
+
from torch.optim import Optimizer
|
| 177 |
+
from torch.optim.lr_scheduler import _LRScheduler
|
| 178 |
+
from torch.utils.data import DataLoader
|
| 179 |
+
|
| 180 |
+
from physicsnemo import Module
|
| 181 |
+
|
| 182 |
+
# T is used to denote a data structure that contains inputs for a model and ground truths
T = TypeVar("T")
# S is used to denote a data structure that has inputs for a model, but no ground truth labels
S = TypeVar("S")
+
|
| 187 |
+
|
| 188 |
+
class ActiveLearningPhase(StrEnum):
    """
    An enumeration of the different phases of the active learning workflow.

    This is primarily used in the metadata for restarting an ongoing active
    learning experiment.
    """

    # As a StrEnum, each member compares equal to its lowercase string value,
    # so these round-trip cleanly through serialized metadata.
    TRAINING = "training"
    METROLOGY = "metrology"
    QUERY = "query"
    LABELING = "labeling"
    DATA_INTEGRATION = "data_integration"
+
|
| 202 |
+
|
| 203 |
+
class AbstractQueue(Protocol[T]):
    """
    Defines a generic queue protocol for data that is passed between active
    learning components.

    This can be a simple local `queue.Queue`, or a more sophisticated
    distributed queue system.

    The primary use case for this is to allow a query strategy to
    enqueue some data structure for the labeling strategy to consume,
    and once the labeling is done, enqueue to a data serialization
    workflow. While there is no explicit restriction on the **type**
    of queue that is implemented, a reasonable assumption to make
    would be a FIFO queue, unless otherwise specified by the concrete
    implementation.

    Optional Serialization Methods
    -------------------------------
    Implementations may optionally provide `to_list()` and `from_list()`
    methods for checkpoint serialization. If not provided, the queue
    will be serialized using `torch.save()` as a fallback.

    Type Parameters
    ---------------
    T
        The type of items that will be stored in the queue.
    """

    # This is a typing.Protocol: methods are declarations only (`...` bodies),
    # and concrete queues match structurally without inheriting from it.

    def put(self, item: T) -> None:
        """
        Method to put a data structure into the queue.

        Parameters
        ----------
        item: T
            The data structure to put into the queue.
        """
        ...

    def get(self) -> T:
        """
        Method to get a data structure from the queue.

        This method should remove the data structure from the queue,
        and return it to a consumer.

        Returns
        -------
        T
            The data structure that was removed from the queue.
        """
        ...

    def empty(self) -> bool:
        """
        Method to check if the queue is empty/has been depleted.

        Returns
        -------
        bool
            True if the queue is empty, False otherwise.
        """
        ...
|
| 266 |
+
|
| 267 |
+
|
| 268 |
+
class DataPool(Protocol[T]):
    """
    An abstract protocol for some reservoir of data that is
    used for some part of active learning, parametrized such
    that it will return data structures of an arbitrary type ``T``.

    **All** methods are left abstract, and need to be defined
    by concrete implementations. For the most part, a `torch.utils.data.Dataset`
    would match this protocol, provided that it implements the ``append`` method
    which will allow data to be persisted to a filesystem.

    Methods
    -------
    __getitem__(self, index: int) -> T:
        Method to get a single data structure from the data pool.
    __len__(self) -> int:
        Method to get the length of the data pool.
    __iter__(self) -> Iterator[T]:
        Method to iterate over the data pool.
    append(self, item: T) -> None:
        Method to append a data structure to the data pool.
    """

    def __getitem__(self, index: int) -> T:
        """
        Method to get a data structure from the data pool.

        This method should retrieve an item from the pool by a
        flat index.

        Parameters
        ----------
        index: int
            The index of the data structure to get.

        Returns
        -------
        T
            The data structure at the given index.
        """
        ...

    def __len__(self) -> int:
        """
        Method to get the length of the data pool.

        Returns
        -------
        int
            The length of the data pool.
        """
        ...

    def __iter__(self) -> Iterator[T]:
        """
        Method to iterate over the data pool.

        This method should return an iterator over the data pool.

        Returns
        -------
        Iterator[T]
            An iterator over the data pool.
        """
        ...

    def append(self, item: T) -> None:
        """
        Method to append a data structure to the data pool.

        For persistent storage pools, this will actually mean that the
        ``item`` is serialized to a filesystem.

        Parameters
        ----------
        item: T
            The data structure to append to the data pool.
        """
        ...
|
| 347 |
+
|
| 348 |
+
|
| 349 |
+
class ActiveLearningProtocol(Protocol):
    """
    This protocol acts as a basis for all active learning protocols.

    This ensures that all protocols have some common interface, for
    example the ability to `attach` to another object for scope
    management.

    Attributes
    ----------
    __protocol_name__: str
        The name of the protocol. This is primarily used for `repr`
        and `str` f-strings. This should be defined by concrete
        implementations.
    _args: dict[str, Any]
        A dictionary of arguments that were used to instantiate the protocol.
        This is used for serialization and deserialization of the protocol,
        and follows the same pattern as the ``_args`` attribute of
        ``physicsnemo.Module``.

    Methods
    -------
    attach(self, other: object) -> None:
        This method is used to attach the current object to another,
        allowing the protocol to access the attached object's scope.
        The use case for this is to allow a protocol access to the
        driver's scope to access dataset, model, etc. as needed.
        This needs to be implemented by concrete implementations.
    is_attached: bool
        Whether the current object is attached to another object.
        This is left abstract, as it depends on how ``attach`` is implemented.
    logger: Logger
        The logger for this protocol. This is used to log information
        about the protocol's progress.
    _setup_logger(self) -> None:
        This method is used to setup the logger for the protocol.
        The default implementation is to configure the logger similarly
        to how ``physicsnemo`` loggers are configured.
    """

    # Human-readable protocol name; set by concrete implementations.
    __protocol_name__: str
    # Which active learning phase this protocol belongs to; used to build
    # per-phase directory paths in ``strategy_dir``.
    __protocol_type__: ActiveLearningPhase
    # Constructor-argument metadata captured by ``__new__`` for checkpointing.
    _args: dict[str, Any]

    def __new__(cls, *args: Any, **kwargs: Any) -> ActiveLearningProtocol:
        """
        Wrapper for instantiating any subclass of `ActiveLearningProtocol`.

        This method will use `inspect` to capture arguments and keyword
        arguments that were used to instantiate the protocol, and stash
        them into the `_args` attribute of the instance, following
        what is done with `physicsnemo.Module`.

        This approach is useful for reconstructing strategies from checkpoints.

        Parameters
        ----------
        args: Any
            Arguments to pass to the protocol's constructor.
        kwargs: Any
            Keyword arguments to pass to the protocol's constructor.

        Returns
        -------
        ActiveLearningProtocol
            A new instance of the protocol class. The instance will have an
            `_args` attribute that contains the keys `__name__`, `__module__`,
            and `__args__` as metadata for the protocol.
        """
        out = super().__new__(cls)

        # Get signature of __init__ function
        sig = inspect.signature(cls.__init__)

        # Bind args and kwargs to signature
        bound_args = sig.bind_partial(
            *([None] + list(args)), **kwargs
        )  # Add None to account for self
        bound_args.apply_defaults()

        # Get args and kwargs (excluding self and unroll kwargs)
        # NOTE: after ``apply_defaults`` the ``arguments`` mapping is in
        # signature order, so zipping it against ``sig.parameters`` keeps
        # each parameter aligned with its bound value.
        instantiate_args = {}
        for param, (k, v) in zip(sig.parameters.values(), bound_args.arguments.items()):
            # Skip self
            if k == "self":
                continue

            # Add args and kwargs to instantiate_args
            if param.kind == param.VAR_KEYWORD:
                # ``v`` is the dict of extra keyword arguments; flatten it
                # so ``__args__`` holds a single-level mapping.
                instantiate_args.update(v)
            else:
                instantiate_args[k] = v

        # Store args needed for instantiation
        out._args = {
            "__name__": cls.__name__,
            "__module__": cls.__module__,
            "__args__": instantiate_args,
        }
        return out

    def attach(self, other: object) -> None:
        """
        This method is used to attach another object to the current protocol,
        allowing the attached object to access the scope of this protocol.
        The primary reason for this is to allow the protocol to access
        things like the dataset, the learner model, etc. as needed.

        Example use cases would be for a query strategy to access the ``unlabeled_pool``;
        for a metrology strategy to access the ``validation_pool``, and for any
        strategy to be able to access the surrogate/learner model.

        This method can be as simple as setting ``self.driver = other``, but
        is left abstract in case there are other potential use cases
        where multiple protocols could share information.

        Parameters
        ----------
        other: object
            The object to attach to.
        """
        ...

    @property
    def is_attached(self) -> bool:
        """
        Property to check if the current object is already attached.

        This is left abstract, as it depends on how ``attach`` is implemented.

        Returns
        -------
        bool
            True if the current object is attached, False otherwise.
        """
        ...

    @property
    def logger(self) -> Logger:
        """
        Property to access the logger for this protocol.

        If the logger has not been configured yet, the property
        will call the `_setup_logger` method to configure it.

        Returns
        -------
        Logger
            The logger for this protocol.
        """
        # Lazily configure on first access so concrete classes need no
        # explicit logger setup in their __init__.
        if not hasattr(self, "_logger"):
            self._setup_logger()
        return self._logger

    @logger.setter
    def logger(self, logger: Logger) -> None:
        """
        Setter for the logger for this protocol.

        Parameters
        ----------
        logger: Logger
            The logger to set for this protocol.
        """
        self._logger = logger

    def _setup_logger(self) -> None:
        """
        Method to setup the logger for all active learning protocols.

        Each protocol should have their own logger, namespaced under
        ``core.active_learning`` by its ``__protocol_name__``.
        """
        self.logger = logging.getLogger(
            f"core.active_learning.{self.__protocol_name__}"
        )
        # Don't add handlers here - let the parent logger handle formatting
        # This prevents duplicate console output
        self.logger.setLevel(logging.WARNING)

    @property
    def strategy_dir(self) -> Path:
        """
        Returns the directory where the underlying strategy can use
        to persist data.

        Depending on the strategy abstraction, further nesting may be
        required (e.g active learning step index, phase, etc.).

        Returns
        -------
        Path
            The directory where the metrology strategy will persist
            its records.

        Raises
        ------
        RuntimeError
            If the metrology strategy is not attached to a driver yet.
        """
        if not self.is_attached:
            raise RuntimeError(
                f"{self.__class__.__name__} is not attached to a driver yet."
            )
        # NOTE(review): ``self.driver`` is not declared on this protocol; it is
        # presumably set by a concrete ``attach`` implementation — confirm.
        # ``str(self.__protocol_type__)`` yields the phase's string value
        # (StrEnum), e.g. "query" or "metrology".
        path = (
            self.driver.log_dir / str(self.__protocol_type__) / self.__class__.__name__
        )
        path.mkdir(parents=True, exist_ok=True)
        return path

    @property
    def checkpoint_dir(self) -> Path:
        """
        Utility property for strategies to conveniently access the checkpoint directory.

        This is useful for (de)serializing data tied to checkpointing.

        Returns
        -------
        Path
            The checkpoint directory, which includes the active learning step index.

        Raises
        ------
        RuntimeError
            If the strategy is not attached to a driver yet.
        """
        if not self.is_attached:
            raise RuntimeError(
                f"{self.__class__.__name__} is not attached to a driver yet."
            )
        # Checkpoints are nested per active-learning step so restarts can
        # locate the state for a specific iteration.
        path = (
            self.driver.log_dir
            / "checkpoints"
            / f"step_{self.driver.active_learning_step_idx}"
        )
        path.mkdir(parents=True, exist_ok=True)
        return path
|
| 586 |
+
|
| 587 |
+
|
| 588 |
+
class QueryStrategy(ActiveLearningProtocol):
    """
    This protocol defines a query strategy for active learning.

    A query strategy is responsible for selecting data points for labeling.
    In the most general sense, concrete instances of this protocol
    will specify how many samples to query, and the heuristics for
    selecting samples.

    Attributes
    ----------
    max_samples: int
        The maximum number of samples to query. This can be interpreted
        as the exact number of samples to query, or as an upper limit
        for querying methods that are threshold based.
    """

    # Upper bound (or exact count) of samples to enqueue per query phase.
    max_samples: int
    __protocol_type__ = ActiveLearningPhase.QUERY

    def sample(self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any) -> None:
        """
        Method that implements the logic behind querying data to be labeled.

        This method should be implemented by concrete implementations,
        and assume that an active learning driver will pass a queue
        for this method to enqueue data to be labeled.

        Additional ``args`` and ``kwargs`` are passed to the method,
        and can be used to pass additional information to the query strategy.

        This method will enqueue in place, and should not return anything.

        Parameters
        ----------
        query_queue: AbstractQueue[T]
            The queue to enqueue data to be labeled.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def __call__(
        self, query_queue: AbstractQueue[T], *args: Any, **kwargs: Any
    ) -> None:
        """
        Syntactic sugar for the ``sample`` method.

        This allows the object to be called as a function, and will pass
        the arguments to the strategy's ``sample`` method.

        Parameters
        ----------
        query_queue: AbstractQueue[T]
            The queue to enqueue data to be labeled.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        self.sample(query_queue, *args, **kwargs)
|
| 651 |
+
|
| 652 |
+
|
| 653 |
+
class LabelStrategy(ActiveLearningProtocol):
    """
    This protocol defines a label strategy for active learning.

    A label strategy is responsible for labeling data points; this may
    be a simple Python function for demonstrating a concept, or an external,
    potentially time consuming and complex, process.

    Attributes
    ----------
    __is_external_process__: bool
        Whether the label strategy is running in an external process.
    __provides_fields__: set[str]
        The fields that the label strategy provides. This should be
        set by concrete implementations, and should be used to write
        and map labeled data to fields within the data structure ``T``.
    """

    # True when labeling delegates to an external (e.g. out-of-process) workflow.
    __is_external_process__: bool
    # Names of fields this strategy writes into labeled items; None until
    # a concrete implementation declares them.
    __provides_fields__: set[str] | None = None
    __protocol_type__ = ActiveLearningPhase.LABELING

    def label(
        self,
        queue_to_label: AbstractQueue[T],
        serialize_queue: AbstractQueue[T],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Method that implements the logic behind labeling data.

        This method should be implemented by concrete implementations,
        and assume that an active learning driver will pass a queue
        for this method to dequeue data to be labeled.

        Parameters
        ----------
        queue_to_label: AbstractQueue[T]
            Queue containing data structures to be labeled. Generally speaking,
            this should be passed over after running query strateg(ies).
        serialize_queue: AbstractQueue[T]
            Queue for enqueuing labeled data to be serialized.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def __call__(
        self,
        queue_to_label: AbstractQueue[T],
        serialize_queue: AbstractQueue[T],
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Syntactic sugar for the ``label`` method.

        This allows the object to be called as a function, and will pass
        the arguments to the strategy's ``label`` method.

        Parameters
        ----------
        queue_to_label: AbstractQueue[T]
            Queue containing data structures to be labeled.
        serialize_queue: AbstractQueue[T]
            Queue for enqueuing labeled data to be serialized.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        self.label(queue_to_label, serialize_queue, *args, **kwargs)
|
| 728 |
+
|
| 729 |
+
|
| 730 |
+
class MetrologyStrategy(ActiveLearningProtocol):
    """
    This protocol defines a metrology strategy for active learning.

    A metrology strategy is responsible for assessing the improvements to the underlying
    model, beyond simple validation metrics. This should reflect the application
    requirements of the model, which may include running a simulation.

    Attributes
    ----------
    records: list[S]
        A sequence of record data structures that records the
        history of the active learning process, as viewed by
        this particular metrology view.
    """

    # History of computed records. Must be initialized (e.g. via ``reset`` or a
    # concrete ``__init__``) before ``append``/``__len__`` are used.
    records: list[S]
    __protocol_type__ = ActiveLearningPhase.METROLOGY

    def append(self, record: S) -> None:
        """
        Method to append a record to the metrology strategy.

        Parameters
        ----------
        record: S
            The record to append to the metrology strategy.
        """
        self.records.append(record)

    def __len__(self) -> int:
        """
        Method to get the length of the metrology strategy.

        Returns
        -------
        int
            The length of the metrology strategy.
        """
        return len(self.records)

    def serialize_records(
        self, path: Path | None = None, *args: Any, **kwargs: Any
    ) -> None:
        """
        Method to serialize the records of the metrology strategy.

        This should be defined by a concrete implementation, which dictates
        how the records are persisted, e.g. to a JSON file, database, etc.

        The `strategy_dir` property can be used to determine the directory where
        the records should be persisted.

        Parameters
        ----------
        path: Path | None
            The path to serialize the records to. If not provided, the strategy
            should provide a reasonable default, such as with the checkpointing
            or within the corresponding metrology directory via `strategy_dir`.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def load_records(self, path: Path | None = None, *args: Any, **kwargs: Any) -> None:
        """
        Method to load the records of the metrology strategy, i.e.
        the reverse of `serialize_records`.

        This should be defined by a concrete implementation, which dictates
        how the records are loaded, e.g. from a JSON file, database, etc.

        If no path is provided, the strategy should load the latest records
        as sensible defaults. The `records` attribute should then be overwritten
        in-place.

        Parameters
        ----------
        path: Path | None
            The path to load the records from. If not provided, the strategy
            should load the latest records as sensible defaults.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def compute(self, *args: Any, **kwargs: Any) -> None:
        """
        Method to compute the metrology strategy. No data is passed to
        this method, as it is expected that the data be drawn as needed
        from various ``DataPool`` connected to the driver.

        This method defines the core logic for computing a particular view
        of performance by the underlying model on the data. Once computed,
        the data needs to be formatted into a record data structure ``S``,
        that is then appended to the ``records`` attribute.

        Parameters
        ----------
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def __call__(self, *args: Any, **kwargs: Any) -> None:
        """
        Syntactic sugar for the ``compute`` method.

        This allows the object to be called as a function, and will pass
        the arguments to the strategy's ``compute`` method.

        Parameters
        ----------
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        self.compute(*args, **kwargs)

    def reset(self) -> None:
        """
        Method to reset any stateful attributes of the metrology strategy.

        By default, the ``records`` attribute is reset to an empty list.
        """
        self.records = []
|
| 863 |
+
|
| 864 |
+
|
| 865 |
+
class TrainingProtocol(Protocol):
    """
    This protocol defines the interface for training steps: given
    a model and some input data, compute the reduced, differentiable
    loss tensor and return it.

    A concrete implementation can simply be a function with a signature that
    matches what is defined in ``__call__``.
    """

    def __call__(
        self, model: Module, data: T, *args: Any, **kwargs: Any
    ) -> torch.Tensor:
        """
        Implements the training logic for a single training sample or batch.

        For a PhysicsNeMo ``Module`` with trainable parameters, the output
        of this function should correspond to a PyTorch tensor that is
        ``backward``-ready. If there are any logging operations associated
        with training, they should be performed within this function.

        For ideal performance, this function should also be wrappable with
        ``StaticCaptureTraining`` for optimization.

        Parameters
        ----------
        model: Module
            The model to train.
        data: T
            The data to train on. This data structure should comprise
            both input and ground truths to compute the loss.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.

        Returns
        -------
        torch.Tensor
            The reduced, differentiable loss tensor.

        Example
        -------
        Minimum viable implementation:
        >>> import torch
        >>> def training_step(model, data):
        ...     output = model(data)
        ...     loss = torch.sum(torch.pow(output - data, 2))
        ...     return loss
        """
        ...
|
| 916 |
+
|
| 917 |
+
|
| 918 |
+
class ValidationProtocol(Protocol):
    """
    This protocol defines the interface for validation steps: given
    a model and some input data, compute metrics of interest and if
    relevant to do so, log the results.

    A concrete implementation can simply be a function with a signature that
    matches what is defined in ``__call__``.
    """

    def __call__(self, model: Module, data: T, *args: Any, **kwargs: Any) -> None:
        """
        Implements the validation logic for a single sample or batch.

        This method will be called in validation steps **only**, and not used
        for training, query, or metrology steps. In those cases, implement the
        ``inference_step`` method instead.

        This function should not return anything, but should contain the logic
        for computing metrics of interest over a validation/test set. If there
        are any logging operations that need to be performed, they should also
        be performed here.

        Depending on the type of model architecture, consider wrapping this method
        with ``StaticCaptureEvaluateNoGrad`` for performance optimizations. This
        should be used if the model does not require autograd as part of its
        forward pass.

        Parameters
        ----------
        model: Module
            The model to validate.
        data: T
            The data to validate on. This data structure should comprise
            both input and ground truths to compute the loss.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.

        Example
        -------
        Minimum viable implementation (note: metrics are logged, not
        returned, since this protocol returns ``None``):
        >>> import torch
        >>> def validation_step(model, data):
        ...     output = model(data)
        ...     loss = torch.sum(torch.pow(output - data, 2))
        ...     print(f"validation loss: {loss.item()}")
        """
        ...
|
| 968 |
+
|
| 969 |
+
|
| 970 |
+
class InferenceProtocol(Protocol):
    """
    This protocol defines the interface for inference steps: given
    a model and some input data, return the output of the model's forward pass.

    A concrete implementation can simply be a function with a signature that
    matches what is defined in ``__call__``.
    """

    def __call__(self, model: Module, data: S, *args: Any, **kwargs: Any) -> Any:
        """
        Implements the inference logic for a single sample or batch.

        This method will be called in query and metrology steps, and should
        return the output of the model's forward pass, likely minimally processed
        so that any transformations can be performed by strategies that utilize
        this protocol.

        The key difference between this protocol and the other two training and
        validation protocols is that the data structure ``S`` does not need
        to contain ground truth values to compute a loss.

        Similar to ``ValidationProtocol``, if relevant to the underlying architecture,
        consider wrapping a concrete implementation of this protocol with
        ``StaticCaptureInference`` for performance optimizations.

        Parameters
        ----------
        model: Module
            The model to infer on.
        data: S
            The data to infer on. This data structure should comprise
            only input values to compute the forward pass.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.

        Returns
        -------
        Any
            The output of the model's forward pass.

        Example
        -------
        Minimum viable implementation:
        >>> def inference_step(model, data):
        ...     output = model(data)
        ...     return output
        """
        ...
|
| 1021 |
+
|
| 1022 |
+
|
| 1023 |
+
class TrainingLoop(Protocol):
    """
    Defines a protocol that implements a training loop.

    This protocol is intended to be called within the active learning loop
    during the training phase, where the model is trained on a specified
    number of epochs or training steps, and optionally validated on a dataset.

    If a ``LearnerProtocol`` is provided, then ``train_step_fn`` and
    ``validate_step_fn`` become optional as they will be defined within the
    ``LearnerProtocol``. If they are provided, however, then they should
    override the ``LearnerProtocol`` variants.

    If graph capture/compilation is intended, then ``train_step_fn`` and
    ``validate_step_fn`` should be wrapped with ``StaticCaptureTraining`` and
    ``StaticCaptureEvaluateNoGrad``, respectively.
    """

    def __call__(
        self,
        model: Module | LearnerProtocol,
        optimizer: Optimizer,
        train_dataloader: DataLoader,
        validation_dataloader: DataLoader | None = None,
        train_step_fn: TrainingProtocol | None = None,
        validate_step_fn: ValidationProtocol | None = None,
        max_epochs: int | None = None,
        max_train_steps: int | None = None,
        max_val_steps: int | None = None,
        lr_scheduler: _LRScheduler | None = None,
        device: str | torch.device | None = None,
        dtype: torch.dtype | None = None,
        *args: Any,
        **kwargs: Any,
    ) -> None:
        """
        Defines the signature for a minimal viable training loop.

        The protocol defines a ``model`` with trainable parameters
        tracked by ``optimizer`` will go through multiple epochs or
        training steps. In the latter, the ``train_dataloader`` will be
        exhausted ``max_epochs`` times, while the mutually exclusive
        ``max_train_steps`` will limit the number of training batches,
        which can be greater or less than the length of the ``train_dataloader``.

        (Optional) Validation is intended to be performed either at the end of a training
        epoch, or when the maximum number of training steps is reached. The
        ``max_val_steps`` parameter can be used to limit the number of batches to validate with
        on a per-epoch basis. Validation is only performed if a ``validate_step_fn`` is provided,
        alongside ``validation_dataloader``.

        The pseudocode for training to ``max_epochs`` would look like this:

        .. code-block:: python

            max_epochs = 10
            for epoch in range(max_epochs):
                for train_idx, batch in enumerate(train_dataloader):
                    optimizer.zero_grad()
                    loss = train_step_fn(model, batch)
                    loss.backward()
                    optimizer.step()
                    if train_idx + 1 == max_train_steps:
                        break
                if validate_step_fn and validation_dataloader:
                    for val_idx, batch in enumerate(validation_dataloader):
                        validate_step_fn(model, batch)
                        if val_idx + 1 == max_val_steps:
                            break

        The pseudocode for training with a ``LearnerProtocol`` would look like this:

        .. code-block:: python

            for epoch in range(max_epochs):
                for train_idx, batch in enumerate(train_dataloader):
                    loss = model.training_step(batch)
                    if train_idx + 1 == max_train_steps:
                        break
                if validation_dataloader:
                    for val_idx, batch in enumerate(validation_dataloader):
                        model.validation_step(batch)
                        if val_idx + 1 == max_val_steps:
                            break

        The key difference between specifying ``train_step_fn`` and ``LearnerProtocol``
        is that the former excludes the backward pass and optimizer step logic,
        whereas the latter encapsulates them.

        The ``device`` and ``dtype`` parameters are used to specify the device and
        dtype to use for the training loop. If not provided, a reasonable default
        should be used (e.g. from ``torch.get_default_device()`` and ``torch.get_default_dtype()``).

        Parameters
        ----------
        model: Module | LearnerProtocol
            The model to train.
        optimizer: Optimizer
            The optimizer to use for training.
        train_dataloader: DataLoader
            The dataloader to use for training.
        validation_dataloader: DataLoader | None
            The dataloader to use for validation.
        train_step_fn: TrainingProtocol | None
            The training function to use for training. This is optional only
            if ``model`` implements the ``LearnerProtocol``. If this is
            provided and ``model`` implements the ``LearnerProtocol``,
            then this function will take precedence over the
            ``LearnerProtocol.training_step`` method.
        validate_step_fn: ValidationProtocol | None
            The validation function to use for validation, only if it is
            provided alongside ``validation_dataloader``. If ``model`` implements
            the ``LearnerProtocol``, then this function will take precedence over
            the ``LearnerProtocol.validation_step`` method.
        max_epochs: int | None
            The maximum number of epochs to train for. Mutually exclusive
            with ``max_train_steps``.
        max_train_steps: int | None
            The maximum number of training steps to perform. Mutually exclusive
            with ``max_epochs``. If this value is greater than the length
            of ``train_dataloader``, then the training loop will recycle the data
            (i.e. more than one epoch) until the maximum number of training steps
            is reached.
        max_val_steps: int | None
            The maximum number of validation steps to perform per training
            epoch. If ``None``, then the full validation set will be used.
        lr_scheduler: _LRScheduler | None
            The learning rate scheduler to use for training. If provided,
            this will be used to update the learning rate of the optimizer
            during training. If not provided, then the learning rate will
            not be adjusted within this function.
        device: str | torch.device | None
            The device to use for the training loop.
        dtype: torch.dtype | None
            The dtype to use for the training loop.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...
|
| 1164 |
+
|
| 1165 |
+
|
| 1166 |
+
class LearnerProtocol:
    """
    This protocol represents the learner part of an active learning
    algorithm.

    This corresponds to a set of trainable parameters that are optimized,
    and subsequently used for inference and evaluation.

    The required methods ensure that classes implementing this protocol
    provide all the required functionality across all active learning steps.
    Keep in mind that, similar to all other protocols in this module, this
    is merely the required interface and not the actual implementation.
    """

    def training_step(self, data: T, *args: Any, **kwargs: Any) -> None:
        """
        Implements the training logic for a single batch.

        This method will be called in training steps **only**, and not used
        for validation, query, or metrology steps. Specifically this means
        that gradients will be computed and used to update parameters.

        In cases where gradients are not needed, consider implementing the
        ``validation_step`` method instead.

        This should mirror the ``TrainingProtocol`` definition, except that
        the model corresponds to this object.

        Parameters
        ----------
        data: T
            The data to train on. Typically assumed to be a batch
            of data.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def validation_step(self, data: T, *args: Any, **kwargs: Any) -> None:
        """
        Implements the validation logic for a single batch.

        This can match the forward pass, without the need for weight updates.
        This method will be called in validation steps **only**, and not used
        for query or metrology steps. In those cases, implement the ``inference_step``
        method instead.

        This should mirror the ``ValidationProtocol`` definition, except that
        the model corresponds to this object.

        Parameters
        ----------
        data: T
            The data to validate on. Typically assumed to be a batch
            of data.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def inference_step(self, data: T | S, *args: Any, **kwargs: Any) -> None:
        """
        Implements the inference logic for a single batch.

        This can match the forward pass exactly, but provides an opportunity
        to differentiate (or lack thereof, with no pun intended). Specifically,
        this method will be called during query and metrology steps.

        This should mirror the ``InferenceProtocol`` definition, except that
        the model corresponds to this object.

        Parameters
        ----------
        data: T | S
            The data to infer on. Typically assumed to be a batch
            of data.
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    @property
    def parameters(self) -> Iterator[torch.Tensor]:
        """
        Returns an iterator over the parameters of the learner.

        NOTE(review): ``torch.nn.Module.parameters`` is a *method*, not a
        property; implementations subclassing ``Module`` should verify this
        access pattern matches their base class.

        Returns
        -------
        Iterator[torch.Tensor]
            An iterator over the parameters of the learner.
        """
        ...

    def forward(self, *args: Any, **kwargs: Any) -> Any:
        """
        Implements the forward pass for a single batch.

        This method is called between all active learning steps, and should
        contain the logic for how a model ingests data and produces predictions.

        Parameters
        ----------
        args: Any
            Additional arguments to pass to the model.
        kwargs: Any
            Additional keyword arguments to pass to the model.

        Returns
        -------
        Any
            The output of the model's forward pass.
        """
        ...
|
| 1288 |
+
|
| 1289 |
+
|
| 1290 |
+
class DriverProtocol:
    """
    Reference interface for an active learning driver.

    A concrete implementation lives in the ``driver`` module; this
    specification exists mostly as documentation and as a type-hinting
    target that avoids circular imports.

    Attributes
    ----------
    learner: LearnerProtocol
        The learner module that will be used as the surrogate within
        the active learning loop.
    query_strategies: list[QueryStrategy]
        The query strategies that will be used for selecting data points to label.
        A list of strategies can be included, and will sequentially be used to
        populate the ``query_queue`` that passes samples over to labeling.
    query_queue: AbstractQueue[T]
        The queue containing data samples to be labeled. ``QueryStrategy`` instances
        should enqueue samples to this queue.
    label_strategy: LabelStrategy | None
        The label strategy that will be used for labeling data points. In contrast
        to the other strategies, only a single label strategy is supported.
        This strategy will consume the ``query_queue`` and enqueue labeled data to
        the ``label_queue``.
    label_queue: AbstractQueue[T] | None
        The queue containing freshly labeled data. ``LabelStrategy`` instances
        should enqueue labeled data to this queue, and the driver will subsequently
        serialize data contained within this queue to a persistent format.
    metrology_strategies: list[MetrologyStrategy] | None
        The metrology strategies that will be used for assessing the performance
        of the surrogate. A list of strategies can be included, and will sequentially
        be used to populate the ``metrology_queue`` that passes data over to the
        learner.
    training_pool: DataPool[T]
        The pool of data to be used for training. This data will be used to train
        the underlying model, and is assumed to be mutable in that additional data
        can be added to the pool over the course of active learning.
    validation_pool: DataPool[T] | None
        The pool of data to be used for validation. This data will be used for both
        conventional validation, as well as for metrology. This dataset is considered
        to be immutable, and should not be modified over the course of active learning.
        This dataset is considered optional, as both validation and metrology are.
    unlabeled_pool: DataPool[T] | None
        An optional pool of data to be used for querying and labeling. If supplied,
        this dataset can be depleted by a query strategy to select data points for labeling.
        In principle, this could also represent a generative model, i.e. not just a static
        dataset, but at a high level represents a distribution of data.
    """

    learner: LearnerProtocol
    query_strategies: list[QueryStrategy]
    query_queue: AbstractQueue[T]
    label_strategy: LabelStrategy | None
    label_queue: AbstractQueue[T] | None
    metrology_strategies: list[MetrologyStrategy] | None
    training_pool: DataPool[T]
    validation_pool: DataPool[T] | None
    unlabeled_pool: DataPool[T] | None

    def active_learning_step(self, *args: Any, **kwargs: Any) -> None:
        """
        Perform a single pass of the active learning loop.

        The intended ordering is: training, metrology, query, labeling,
        where metrology and labeling are optional phases.

        Parameters
        ----------
        args: Any
            Additional arguments to pass to the method.
        kwargs: Any
            Additional keyword arguments to pass to the method.
        """
        ...

    def _setup_logger(self) -> None:
        """
        Configure the driver's logger.

        A concrete implementation should support scoped logging so that
        state such as the active learning iteration count can be recorded.
        """
        ...

    def attach_strategies(self) -> None:
        """
        Attach every configured strategy to this driver.

        Each strategy's ``attach`` method is invoked with the driver, which
        gives the strategy access to the driver's scope — e.g. a query
        strategy reaching the ``unlabeled_pool``, a metrology strategy
        reaching the ``validation_pool``, or any non-label strategy reaching
        the underlying model (``LearnerProtocol``).
        """
        # Gather strategies in the same order the loop expects:
        # queries first, then the (optional) label and metrology strategies.
        pending = list(self.query_strategies)
        if self.label_strategy:
            pending.append(self.label_strategy)
        if self.metrology_strategies:
            pending.extend(self.metrology_strategies)
        for candidate in pending:
            candidate.attach(self)
|
physics_mcp/source/physicsnemo/constants.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
constant values used by PhysicsNeMo
|
| 19 |
+
"""
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
# string used to determine derivatives
|
| 25 |
+
# Separator used when composing derivative identifiers.
diff_str: str = "__"


def diff(y: str, x: str, degree: int = 1) -> str:
    """
    Build the canonical derivative name of ``y`` with respect to ``x``.

    The variable name and ``degree`` copies of ``x`` are joined with
    ``diff_str``, e.g. ``diff("u", "x", 2)`` gives ``"u__x__x"``.
    """
    tokens = [y]
    tokens.extend([x] * degree)
    return diff_str.join(tokens)


# Default floating-point dtypes (swap for float16 or float64 runs).
tf_dt = torch.float32
np_dt = np.float32

# Toggle for tensorboard-style summary naming.
TF_SUMMARY = False

# PyTorch version for which JIT is enabled by default
# (torch build shipped in NGC container 22.08).
JIT_PYTORCH_VERSION = "1.13.0a0+d321be6"

# (shift, scale) pair that leaves values unchanged.
NO_OP_SCALE = (0.0, 1.0)

# Normalization bounds that effectively perform no normalization.
NO_OP_NORM = (-1.0, 1.0)
|
physics_mcp/source/physicsnemo/datapipes/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/darcy.py
ADDED
|
@@ -0,0 +1,322 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import warp as wp
|
| 24 |
+
|
| 25 |
+
from ..datapipe import Datapipe
|
| 26 |
+
from ..meta import DatapipeMetaData
|
| 27 |
+
from .kernels.finite_difference import (
|
| 28 |
+
darcy_mgrid_jacobi_iterative_batched_2d,
|
| 29 |
+
mgrid_inf_residual_batched_2d,
|
| 30 |
+
)
|
| 31 |
+
from .kernels.initialization import init_uniform_random_4d
|
| 32 |
+
from .kernels.utils import (
|
| 33 |
+
bilinear_upsample_batched_2d,
|
| 34 |
+
fourier_to_array_batched_2d,
|
| 35 |
+
threshold_3d,
|
| 36 |
+
)
|
| 37 |
+
|
| 38 |
+
Tensor = torch.Tensor
|
| 39 |
+
# TODO unsure if better to remove this. Keeping this in for now
|
| 40 |
+
wp.init()
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
@dataclass
class MetaData(DatapipeMetaData):
    """Datapipe metadata flags for the 2D Darcy benchmark pipeline."""

    name: str = "Darcy2D"
    # Optimization: place data on the compute device automatically and
    # allow capture into CUDA graphs.
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel: this datapipe does not shard data across DDP ranks.
    ddp_sharding: bool = False
| 51 |
+
|
| 52 |
+
|
| 53 |
+
class Darcy2D(Datapipe):
|
| 54 |
+
"""2D Darcy flow benchmark problem datapipe.
|
| 55 |
+
|
| 56 |
+
This datapipe continuously generates solutions to the 2D Darcy equation with variable
|
| 57 |
+
permeability. All samples are generated on the fly and is meant to be a benchmark
|
| 58 |
+
problem for testing data driven models. Permeability is drawn from a random Fourier
|
| 59 |
+
series and threshold it to give a piecewise constant function. The solution is obtained
|
| 60 |
+
using a GPU enabled multi-grid Jacobi iterative method.
|
| 61 |
+
|
| 62 |
+
Parameters
|
| 63 |
+
----------
|
| 64 |
+
resolution : int, optional
|
| 65 |
+
Resolution to run simulation at, by default 256
|
| 66 |
+
batch_size : int, optional
|
| 67 |
+
Batch size of simulations, by default 64
|
| 68 |
+
nr_permeability_freq : int, optional
|
| 69 |
+
Number of frequencies to use for generating random permeability. Higher values
|
| 70 |
+
will give higher freq permeability fields., by default 5
|
| 71 |
+
max_permeability : float, optional
|
| 72 |
+
Max permeability, by default 2.0
|
| 73 |
+
min_permeability : float, optional
|
| 74 |
+
Min permeability, by default 0.5
|
| 75 |
+
max_iterations : int, optional
|
| 76 |
+
Maximum iterations to use for each multi-grid, by default 30000
|
| 77 |
+
convergence_threshold : float, optional
|
| 78 |
+
Solver L-Infinity convergence threshold, by default 1e-6
|
| 79 |
+
iterations_per_convergence_check : int, optional
|
| 80 |
+
Number of Jacobi iterations to run before checking convergence, by default 1000
|
| 81 |
+
nr_multigrids : int, optional
|
| 82 |
+
Number of multi-grid levels, by default 4
|
| 83 |
+
normaliser : Union[Dict[str, Tuple[float, float]], None], optional
|
| 84 |
+
Dictionary with keys `permeability` and `darcy`. The values for these keys are two floats corresponding to mean and std `(mean, std)`.
|
| 85 |
+
device : Union[str, torch.device], optional
|
| 86 |
+
Device for datapipe to run place data on, by default "cuda"
|
| 87 |
+
|
| 88 |
+
Raises
|
| 89 |
+
------
|
| 90 |
+
ValueError
|
| 91 |
+
Incompatable multi-grid and resolution settings
|
| 92 |
+
"""
|
| 93 |
+
|
| 94 |
+
    def __init__(
        self,
        resolution: int = 256,
        batch_size: int = 64,
        nr_permeability_freq: int = 5,
        max_permeability: float = 2.0,
        min_permeability: float = 0.5,
        max_iterations: int = 30000,
        convergence_threshold: float = 1e-6,
        iterations_per_convergence_check: int = 1000,
        nr_multigrids: int = 4,
        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        device: Union[str, torch.device] = "cuda",
    ):
        """Validate settings and pre-allocate warp arrays for batch generation."""
        super().__init__(meta=MetaData())

        # simulation params
        self.resolution = resolution
        self.batch_size = batch_size
        self.nr_permeability_freq = nr_permeability_freq
        self.max_permeability = max_permeability
        self.min_permeability = min_permeability
        self.max_iterations = max_iterations
        self.convergence_threshold = convergence_threshold
        self.iterations_per_convergence_check = iterations_per_convergence_check
        self.nr_multigrids = nr_multigrids
        self.normaliser = normaliser

        # check normaliser keys: both fields must be present so __iter__ can
        # normalise permeability and darcy outputs
        if self.normaliser is not None:
            if not {"permeability", "darcy"}.issubset(set(self.normaliser.keys())):
                raise ValueError(
                    "normaliser need to have keys permeability and darcy with mean and std"
                )

        # Set up device for warp, warp has same naming convention as torch.
        if isinstance(device, torch.device):
            device = str(device)
        self.device = device

        # spatial dims
        self.dx = 1.0 / (self.resolution + 1)  # pad edges by 1 for multi-grid
        self.dim = (self.batch_size, self.resolution + 1, self.resolution + 1)
        # Fourier coefficient layout: 4 phase/amplitude channels per batch
        # element over an (freq x freq) grid of modes.
        self.fourier_dim = (
            4,
            self.batch_size,
            self.nr_permeability_freq,
            self.nr_permeability_freq,
        )

        # assert resolution is compatible with multi-grid method: the coarsest
        # grid reduction factor 2**(nr_multigrids-1) must divide the resolution
        if (resolution % 2 ** (nr_multigrids - 1)) != 0:
            raise ValueError("Resolution is incompatible with number of sub grids.")

        # allocate arrays for constructing dataset; darcy0/darcy1 form a
        # double buffer swapped between Jacobi iterations
        self.darcy0 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.darcy1 = wp.zeros(self.dim, dtype=float, device=self.device)
        self.permeability = wp.zeros(self.dim, dtype=float, device=self.device)
        self.rand_fourier = wp.zeros(self.fourier_dim, dtype=float, device=self.device)
        # single-element accumulator for the L-infinity residual check
        self.inf_residual = wp.zeros([1], dtype=float, device=self.device)

        # Output tensors (populated when a batch is generated)
        self.output_k = None
        self.output_p = None
|
| 158 |
+
|
| 159 |
+
    def initialize_batch(self) -> None:
        """Initializes arrays for new batch of simulations"""

        # initialize permeability: draw random Fourier coefficients, evaluate
        # the series on the grid, then threshold to a piecewise-constant field
        self.permeability.zero_()
        # fresh RNG seed per batch so successive batches differ
        seed = np.random.randint(np.iinfo(np.uint64).max, dtype=np.uint64)
        wp.launch(
            kernel=init_uniform_random_4d,
            dim=self.fourier_dim,
            inputs=[self.rand_fourier, -1.0, 1.0, seed],
            device=self.device,
        )
        # evaluate the random Fourier series into the permeability array
        wp.launch(
            kernel=fourier_to_array_batched_2d,
            dim=self.dim,
            inputs=[
                self.permeability,
                self.rand_fourier,
                self.nr_permeability_freq,
                self.resolution,
                self.resolution,
            ],
            device=self.device,
        )
        # threshold at 0.0: values map to min/max permeability, giving a
        # piecewise-constant permeability field
        wp.launch(
            kernel=threshold_3d,
            dim=self.dim,
            inputs=[
                self.permeability,
                0.0,
                self.min_permeability,
                self.max_permeability,
            ],
            device=self.device,
        )

        # zero darcy arrays (both Jacobi double buffers) before solving
        self.darcy0.zero_()
        self.darcy1.zero_()
|
| 198 |
+
|
| 199 |
+
def generate_batch(self) -> None:
    """Solve for new batch of simulations"""

    # draw a new random permeability field and zero the solution buffers
    self.initialize_batch()

    # multi-grid sweep: coarsest level first, finest (factor 1) last
    for level in range(self.nr_multigrids):
        reduction = 2 ** (self.nr_multigrids - level - 1)
        if reduction > 1:
            level_dim = tuple(
                [self.batch_size] + 2 * [self.resolution // reduction]
            )
        else:
            level_dim = self.dim

        # iterate until convergence or until the step budget is exhausted
        nr_convergence_checks = (
            self.max_iterations // self.iterations_per_convergence_check
        )
        for _ in range(nr_convergence_checks):
            # block of Jacobi relaxation steps between residual checks
            for _ in range(self.iterations_per_convergence_check):
                wp.launch(
                    kernel=darcy_mgrid_jacobi_iterative_batched_2d,
                    dim=level_dim,
                    inputs=[
                        self.darcy0,
                        self.darcy1,
                        self.permeability,
                        1.0,
                        self.dim[1],
                        self.dim[2],
                        self.dx,
                        reduction,
                    ],
                    device=self.device,
                )
                # ping-pong the solution buffers
                self.darcy0, self.darcy1 = self.darcy1, self.darcy0

            # infinity-norm residual between the last two iterates
            self.inf_residual.zero_()
            wp.launch(
                kernel=mgrid_inf_residual_batched_2d,
                dim=level_dim,
                inputs=[
                    self.darcy0,
                    self.darcy1,
                    self.inf_residual,
                    reduction,
                ],
                device=self.device,
            )
            residual = self.inf_residual.numpy()[0]

            # threshold is scaled by the reduction factor at coarse levels
            if residual < (self.convergence_threshold * reduction):
                break

        # interpolate the coarse solution up before the next (finer) level
        if reduction > 1:
            wp.launch(
                kernel=bilinear_upsample_batched_2d,
                dim=self.dim,
                inputs=[
                    self.darcy0,
                    self.dim[1],
                    self.dim[2],
                    reduction,
                ],
                device=self.device,
            )
def __iter__(self) -> Tuple[Tensor, Tensor]:
    """
    Yields
    ------
    Iterator[Tuple[Tensor, Tensor]]
        Infinite iterator yielding a dict with keys ``permeability`` and
        ``darcy`` holding tensors of size [batch, 1, resolution, resolution]
    """
    # infinite generator: each pass runs a fresh batch of simulations
    while True:
        self.generate_batch()

        # expose the warp buffers as torch tensors and add a channel dim
        k_field = torch.unsqueeze(wp.to_torch(self.permeability), 1)
        p_field = torch.unsqueeze(wp.to_torch(self.darcy0), 1)

        # crop edges by 1 from multi-grid TODO messy
        k_field = k_field[:, :, : self.resolution, : self.resolution]
        p_field = p_field[:, :, : self.resolution, : self.resolution]

        # optional (value - mean) / std normalization
        if self.normaliser is not None:
            k_norm = self.normaliser["permeability"]
            p_norm = self.normaliser["darcy"]
            k_field = (k_field - k_norm[0]) / k_norm[1]
            p_field = (p_field - p_norm[0]) / p_norm[1]

        # static output tensors so CUDA graphs see stable addresses
        if self.output_k is None:
            self.output_k = k_field
            self.output_p = p_field
        else:
            self.output_k.data.copy_(k_field)
            self.output_p.data.copy_(p_field)

        yield {"permeability": self.output_k, "darcy": self.output_p}
def __len__(self):
    # effectively infinite: batches are generated on the fly forever
    return sys.maxsize
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kelvin_helmholtz.py
ADDED
|
@@ -0,0 +1,436 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import sys
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Dict, Tuple, Union
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
import warp as wp
|
| 24 |
+
|
| 25 |
+
from ..datapipe import Datapipe
|
| 26 |
+
from ..meta import DatapipeMetaData
|
| 27 |
+
from .kernels.finite_volume import (
|
| 28 |
+
euler_apply_flux_batched_2d,
|
| 29 |
+
euler_conserved_to_primitive_batched_2d,
|
| 30 |
+
euler_extrapolation_batched_2d,
|
| 31 |
+
euler_get_flux_batched_2d,
|
| 32 |
+
euler_primitive_to_conserved_batched_2d,
|
| 33 |
+
initialize_kelvin_helmoltz_batched_2d,
|
| 34 |
+
)
|
| 35 |
+
from .kernels.initialization import init_uniform_random_2d
|
| 36 |
+
|
| 37 |
+
Tensor = torch.Tensor
|
| 38 |
+
# TODO unsure if better to remove this
|
| 39 |
+
wp.init()
|
| 40 |
+
|
| 41 |
+
|
| 42 |
+
@dataclass
class MetaData(DatapipeMetaData):
    # Identifier for this datapipe
    name: str = "KelvinHelmholtz2D"
    # Optimization capabilities
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallelism: sharding across DDP ranks is not supported
    ddp_sharding: bool = False
|
| 52 |
+
class KelvinHelmholtz2D(Datapipe):
    """Kelvin-Helmholtz instability benchmark problem datapipe.

    This datapipe continuously generates samples with random initial conditions. All samples
    are generated on the fly and is meant to be a benchmark problem for testing data driven
    models. Initial conditions are given in the form of small perturbations. The solution
    is obtained using a GPU enabled Finite Volume Method.

    Parameters
    ----------
    resolution : int, optional
        Resolution to run simulation at, by default 512
    batch_size : int, optional
        Batch size of simulations, by default 16
    seq_length : int, optional
        Sequence length of output samples, by default 8
    nr_perturbation_freq : int, optional
        Number of frequencies to use for generating random initial perturbations, by default 5
    perturbation_range : float, optional
        Range to use for random perturbations. This value will be the max amplitude of the
        initial perturbation, by default 0.1
    nr_snapshots : int, optional
        Number of snapshots of simulation to generate for data generation. This will
        control how long the simulation is run for, by default 256
    iteration_per_snapshot : int, optional
        Number of finite volume steps to take between each snapshot. Each step size is
        fixed as the smallest possible value that satisfies the Courant-Friedrichs-Lewy
        condition, by default 32
    gamma : float, optional
        Heat capacity ratio, by default 5.0/3.0
    normaliser : Union[Dict[str, Tuple[float, float]], None], optional
        Dictionary with keys `density`, `velocity`, and `pressure`. The values for these keys are two floats corresponding to mean and std `(mean, std)`.
    device : Union[str, torch.device], optional
        Device for datapipe to run place data on, by default "cuda"
    """

    def __init__(
        self,
        resolution: int = 512,
        batch_size: int = 16,
        seq_length: int = 8,
        nr_perturbation_freq: int = 5,
        perturbation_range: float = 0.1,
        nr_snapshots: int = 256,
        iteration_per_snapshot: int = 32,
        gamma: float = 5.0 / 3.0,
        normaliser: Union[Dict[str, Tuple[float, float]], None] = None,
        device: Union[str, torch.device] = "cuda",
    ):
        super().__init__(meta=MetaData())

        # store simulation parameters
        self.resolution = resolution
        self.batch_size = batch_size
        self.seq_length = seq_length
        self.nr_perturbation_freq = nr_perturbation_freq
        self.perturbation_range = perturbation_range
        self.nr_snapshots = nr_snapshots
        self.iteration_per_snapshot = iteration_per_snapshot
        self.gamma = gamma
        self.courant_fac = 0.4  # hard set
        self.normaliser = normaliser

        # validate normaliser keys up front
        if self.normaliser is not None:
            required = {"density", "velocity", "pressure"}
            if not required.issubset(set(self.normaliser.keys())):
                raise ValueError(
                    "normaliser need to have keys `density`, `velocity` and `pressure` with mean and std"
                )

        # warp uses the same device naming convention as torch
        if isinstance(device, torch.device):
            device = str(device)
        self.device = device

        # grid spacing, CFL-limited time step, and cell volume
        self.dx = 1.0 / resolution
        # hard set to smallest possible step needed
        self.dt = self.courant_fac * self.dx / (np.sqrt(self.gamma * 5.0) + 2.0)
        self.vol = self.dx**2
        self.dim = (self.batch_size, self.resolution, self.resolution)

        # random perturbation amplitude per (batch, frequency)
        self.w = wp.zeros(
            (self.batch_size, self.nr_perturbation_freq),
            dtype=float,
            device=self.device,
        )

        # local allocation helpers for scalar / 2d-vector fields
        def _scalar():
            return wp.zeros(self.dim, dtype=float, device=self.device)

        def _vector():
            return wp.zeros(self.dim, dtype=wp.vec2, device=self.device)

        # conserved quantities
        self.mass = _scalar()
        self.mom = _vector()
        self.e = _scalar()

        # primitive quantities
        self.rho = _scalar()
        self.vel = _vector()
        self.p = _scalar()

        # flux work buffers
        self.mass_flux_x = _scalar()
        self.mass_flux_y = _scalar()
        self.mom_flux_x = _vector()
        self.mom_flux_y = _vector()
        self.e_flux_x = _scalar()
        self.e_flux_y = _scalar()

        # face-extrapolation work buffers
        self.rho_xl = _scalar()
        self.rho_xr = _scalar()
        self.rho_yl = _scalar()
        self.rho_yr = _scalar()
        self.vel_xl = _vector()
        self.vel_xr = _vector()
        self.vel_yl = _vector()
        self.vel_yr = _vector()
        self.p_xl = _scalar()
        self.p_xr = _scalar()
        self.p_yl = _scalar()
        self.p_yr = _scalar()

        # snapshot storage for the full simulated trajectory
        self.seq_rho = [_scalar() for _ in range(self.nr_snapshots)]
        self.seq_vel = [_vector() for _ in range(self.nr_snapshots)]
        self.seq_p = [_scalar() for _ in range(self.nr_snapshots)]

        # static output tensors (allocated lazily, reused for CUDA graphs)
        self.output_rho = None
        self.output_vel = None
        self.output_p = None

    def initialize_batch(self) -> None:
        """Initializes arrays for new batch of simulations"""

        # draw fresh random Fourier perturbation amplitudes
        rng_seed = np.random.randint(np.iinfo(np.uint64).max, dtype=np.uint64)
        wp.launch(
            init_uniform_random_2d,
            dim=[self.batch_size, self.nr_perturbation_freq],
            inputs=[
                self.w,
                -self.perturbation_range,
                self.perturbation_range,
                rng_seed,
            ],
            device=self.device,
        )

        # set the perturbed Kelvin-Helmholtz initial primitive fields
        wp.launch(
            initialize_kelvin_helmoltz_batched_2d,
            dim=self.dim,
            inputs=[
                self.rho,
                self.vel,
                self.p,
                self.w,
                0.05 / np.sqrt(2.0),
                self.dim[1],
                self.dim[2],
                self.nr_perturbation_freq,
            ],
            device=self.device,
        )

        # convert primitives to the conserved variables the solver steps
        wp.launch(
            euler_primitive_to_conserved_batched_2d,
            dim=self.dim,
            inputs=[
                self.rho,
                self.vel,
                self.p,
                self.mass,
                self.mom,
                self.e,
                self.gamma,
                self.vol,
                self.dim[1],
                self.dim[2],
            ],
            device=self.device,
        )

    def generate_batch(self) -> None:
        """Solve for new batch of simulations"""

        # set up random initial conditions
        self.initialize_batch()

        nx, ny = self.dim[1], self.dim[2]

        # march the solver, recording a snapshot before each block of steps
        for snap in range(self.nr_snapshots):
            wp.copy(self.seq_rho[snap], self.rho)
            wp.copy(self.seq_vel[snap], self.vel)
            wp.copy(self.seq_p[snap], self.p)

            for _ in range(self.iteration_per_snapshot):
                # conserved -> primitive
                wp.launch(
                    euler_conserved_to_primitive_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.mass,
                        self.mom,
                        self.e,
                        self.rho,
                        self.vel,
                        self.p,
                        self.gamma,
                        self.vol,
                        nx,
                        ny,
                    ],
                    device=self.device,
                )

                # extrapolate primitives to cell faces
                wp.launch(
                    euler_extrapolation_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.rho,
                        self.vel,
                        self.p,
                        self.rho_xl,
                        self.rho_xr,
                        self.rho_yl,
                        self.rho_yr,
                        self.vel_xl,
                        self.vel_xr,
                        self.vel_yl,
                        self.vel_yr,
                        self.p_xl,
                        self.p_xr,
                        self.p_yl,
                        self.p_yr,
                        self.gamma,
                        self.dx,
                        self.dt,
                        nx,
                        ny,
                    ],
                    device=self.device,
                )

                # compute face fluxes
                wp.launch(
                    euler_get_flux_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.rho_xl,
                        self.rho_xr,
                        self.rho_yl,
                        self.rho_yr,
                        self.vel_xl,
                        self.vel_xr,
                        self.vel_yl,
                        self.vel_yr,
                        self.p_xl,
                        self.p_xr,
                        self.p_yl,
                        self.p_yr,
                        self.mass_flux_x,
                        self.mass_flux_y,
                        self.mom_flux_x,
                        self.mom_flux_y,
                        self.e_flux_x,
                        self.e_flux_y,
                        self.gamma,
                        nx,
                        ny,
                    ],
                    device=self.device,
                )

                # apply fluxes to update conserved quantities
                wp.launch(
                    euler_apply_flux_batched_2d,
                    dim=self.dim,
                    inputs=[
                        self.mass_flux_x,
                        self.mass_flux_y,
                        self.mom_flux_x,
                        self.mom_flux_y,
                        self.e_flux_x,
                        self.e_flux_y,
                        self.mass,
                        self.mom,
                        self.e,
                        self.dx,
                        self.dt,
                        nx,
                        ny,
                    ],
                    device=self.device,
                )

    def __iter__(self) -> Tuple[Tensor, Tensor, Tensor]:
        """
        Yields
        ------
        Iterator[Tuple[Tensor, Tensor]]
            Infinite iterator yielding a dict with keys ``density``, ``velocity``
            and ``pressure`` holding batches of timeseries of size
            [batch, seq_length, dim, resolution, resolution]
        """
        # infinite generator
        while True:
            # run a fresh batch of simulations
            self.generate_batch()

            # per-batch-element random ordering of valid sequence start indices
            nr_starts = self.nr_snapshots - self.seq_length
            batch_ind = [np.arange(nr_starts) for _ in range(self.batch_size)]
            for order in batch_ind:
                np.random.shuffle(order)

            # emit every sample generated before re-running the simulation
            for bb in range(nr_starts):
                batched_seq_rho = []
                batched_seq_vel = []
                batched_seq_p = []
                for b in range(self.batch_size):
                    seq_rho = []
                    seq_vel = []
                    seq_p = []
                    for s in range(self.seq_length):
                        snap = batch_ind[b][bb] + s

                        # pull one batch element from the stored snapshots
                        rho = wp.to_torch(self.seq_rho[snap])[b]
                        vel = wp.to_torch(self.seq_vel[snap])[b]
                        p = wp.to_torch(self.seq_p[snap])[b]

                        # channel-first layout
                        rho = torch.unsqueeze(rho, 0)
                        vel = torch.permute(vel, (2, 0, 1))
                        p = torch.unsqueeze(p, 0)

                        # optional (value - mean) / std normalization
                        if self.normaliser is not None:
                            rho_norm = self.normaliser["density"]
                            vel_norm = self.normaliser["velocity"]
                            p_norm = self.normaliser["pressure"]
                            rho = (rho - rho_norm[0]) / rho_norm[1]
                            vel = (vel - vel_norm[0]) / vel_norm[1]
                            p = (p - p_norm[0]) / p_norm[1]

                        seq_rho.append(rho)
                        seq_vel.append(vel)
                        seq_p.append(p)

                    # stack the time dimension
                    batched_seq_rho.append(torch.stack(seq_rho, 0))
                    batched_seq_vel.append(torch.stack(seq_vel, 0))
                    batched_seq_p.append(torch.stack(seq_p, 0))

                # static output tensors so CUDA graphs see stable addresses
                if self.output_rho is None:
                    self.output_rho = torch.stack(batched_seq_rho, 0)
                    self.output_vel = torch.stack(batched_seq_vel, 0)
                    self.output_p = torch.stack(batched_seq_p, 0)
                else:
                    self.output_rho.data.copy_(torch.stack(batched_seq_rho, 0))
                    self.output_vel.data.copy_(torch.stack(batched_seq_vel, 0))
                    self.output_p.data.copy_(torch.stack(batched_seq_p, 0))

                yield {
                    "density": self.output_rho,
                    "velocity": self.output_vel,
                    "pressure": self.output_p,
                }

    def __len__(self):
        # effectively infinite: batches are generated on the fly forever
        return sys.maxsize
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_difference.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
try:
|
| 19 |
+
import warp as wp
|
| 20 |
+
except ImportError:
|
| 21 |
+
print(
|
| 22 |
+
"""NVIDIA WARP is required for this datapipe. This package is under the
|
| 23 |
+
NVIDIA Source Code License (NVSCL). To install use:
|
| 24 |
+
|
| 25 |
+
pip install warp-lang
|
| 26 |
+
"""
|
| 27 |
+
)
|
| 28 |
+
raise SystemExit(1)
|
| 29 |
+
|
| 30 |
+
from .indexing import index_clamped_edges_batched_2d, index_zero_edges_batched_2d
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
@wp.kernel
def darcy_mgrid_jacobi_iterative_batched_2d(
    darcy0: wp.array3d(dtype=float),
    darcy1: wp.array3d(dtype=float),
    permeability: wp.array3d(dtype=float),
    source: float,
    lx: int,
    ly: int,
    dx: float,
    mgrid_reduction_factor: int,
):  # pragma: no cover
    """Multi-grid Jacobi step for the Darcy equation.

    Parameters
    ----------
    darcy0 : wp.array3d
        Darcy solution previous step
    darcy1 : wp.array3d
        Darcy solution for next step
    permeability : wp.array3d
        Permeability field for Darcy equation
    source : float
        Source value for Darcy equation
    lx : int
        Length of domain in x dim
    ly : int
        Length of domain in y dim
    dx : float
        Grid cell size
    mgrid_reduction_factor : int
        Current multi-grid running at
    """

    b, x, y = wp.tid()

    # map the coarse-grid thread index onto the fine-grid array
    ix = mgrid_reduction_factor * x + (mgrid_reduction_factor - 1)
    iy = mgrid_reduction_factor * y + (mgrid_reduction_factor - 1)
    # effective cell size at this multi-grid level
    h = dx * wp.float32(mgrid_reduction_factor)

    # Darcy stencil neighbors (zero outside the domain)
    d_w = index_zero_edges_batched_2d(
        darcy0, b, ix - mgrid_reduction_factor, iy, lx, ly
    )
    d_e = index_zero_edges_batched_2d(
        darcy0, b, ix + mgrid_reduction_factor, iy, lx, ly
    )
    d_s = index_zero_edges_batched_2d(
        darcy0, b, ix, iy - mgrid_reduction_factor, lx, ly
    )
    d_n = index_zero_edges_batched_2d(
        darcy0, b, ix, iy + mgrid_reduction_factor, lx, ly
    )

    # permeability stencil (clamped at the domain edges)
    k_c = index_clamped_edges_batched_2d(permeability, b, ix, iy, lx, ly)
    k_w = index_clamped_edges_batched_2d(
        permeability, b, ix - mgrid_reduction_factor, iy, lx, ly
    )
    k_e = index_clamped_edges_batched_2d(
        permeability, b, ix + mgrid_reduction_factor, iy, lx, ly
    )
    k_s = index_clamped_edges_batched_2d(
        permeability, b, ix, iy - mgrid_reduction_factor, lx, ly
    )
    k_n = index_clamped_edges_batched_2d(
        permeability, b, ix, iy + mgrid_reduction_factor, lx, ly
    )

    # assemble the discretized terms
    h2 = h * h
    lap_term = k_c * (d_w + d_e + d_s + d_n) / h2
    grad_x_term = ((k_e - k_w) * (d_e - d_w)) / (2.0 * h)
    grad_y_term = ((k_n - k_s) * (d_n - d_s)) / (2.0 * h)

    # Jacobi update; buffers are swapped by the caller each iteration
    darcy1[b, ix, iy] = (lap_term + grad_x_term + grad_y_term + source) / (
        k_c * 4.0 / h2
    )
|
| 116 |
+
@wp.kernel
def mgrid_inf_residual_batched_2d(
    phi0: wp.array3d(dtype=float),
    phi1: wp.array3d(dtype=float),
    inf_res: wp.array(dtype=float),
    mgrid_reduction_factor: int,
):  # pragma: no cover
    """Infinity norm for checking multi-grid solutions.

    Parameters
    ----------
    phi0 : wp.array3d
        Previous solution
    phi1 : wp.array3d
        Current solution
    inf_res : wp.array
        Array to hold infinity norm value in
    mgrid_reduction_factor : int
        Current multi-grid running at
    """
    b, x, y = wp.tid()
    # map the coarse-grid thread index onto the fine-grid array
    ix = mgrid_reduction_factor * x + (mgrid_reduction_factor - 1)
    iy = mgrid_reduction_factor * y + (mgrid_reduction_factor - 1)
    # fold this cell's absolute difference into the running maximum
    diff = wp.abs(phi0[b, ix, iy] - phi1[b, ix, iy])
    wp.atomic_max(inf_res, 0, diff)
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/finite_volume.py
ADDED
|
@@ -0,0 +1,759 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import warp as wp
|
| 19 |
+
except ImportError:
|
| 20 |
+
print(
|
| 21 |
+
"""NVIDIA WARP is required for this datapipe. This package is under the
|
| 22 |
+
NVIDIA Source Code License (NVSCL). To install use:
|
| 23 |
+
|
| 24 |
+
pip install warp-lang
|
| 25 |
+
"""
|
| 26 |
+
)
|
| 27 |
+
raise SystemExit(1)
|
| 28 |
+
|
| 29 |
+
from .indexing import (
|
| 30 |
+
index_periodic_edges_batched_2d,
|
| 31 |
+
index_vec2_periodic_edges_batched_2d,
|
| 32 |
+
)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
@wp.func
def extrapolate_to_face_2d(
    f: float, f_dx: float, f_dy: float, dx: float
):  # pragma: no cover
    """Linearly extrapolate a cell-centered value to its four face centers.

    The value is shifted by half a cell along each axis using the supplied
    one-sided gradients.

    Parameters
    ----------
    f : float
        Cell-centered value
    f_dx : float
        X derivative of the cell value
    f_dy : float
        Y derivative of the cell value
    dx : float
        Cell size

    Returns
    -------
    wp.vec4
        (value on left x face, right x face, left y face, right y face)
    """
    # half-cell displacement along each axis
    half_dx = dx / 2.0
    shift_x = f_dx * half_dx
    shift_y = f_dy * half_dx
    return wp.vec4(f - shift_x, f + shift_x, f - shift_y, f + shift_y)
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
@wp.func
def apply_flux_2d(
    f: float,
    flux_f_xl_dx: float,
    flux_f_xr_dx: float,
    flux_f_yl_dy: float,
    flux_f_yr_dy: float,
    dx: float,
    dt: float,
):  # pragma: no cover
    """Update a conserved cell value with the net flux through its faces.

    Left-face fluxes are subtracted and right-face fluxes added, each
    scaled by the face area ``dx`` and the time step ``dt``.

    Parameters
    ----------
    f : float
        Cell value
    flux_f_xl_dx : float
        Left x flux
    flux_f_xr_dx : float
        Right x flux
    flux_f_yl_dy : float
        Left y flux
    flux_f_yr_dy : float
        Right y flux
    dx : float
        Cell size
    dt : float
        Time step size

    Returns
    -------
    float
        Cell value with added flux
    """
    # common face weight: time step times face length
    face_weight = dt * dx
    f -= face_weight * flux_f_xl_dx
    f += face_weight * flux_f_xr_dx
    f -= face_weight * flux_f_yl_dy
    f += face_weight * flux_f_yr_dy
    return f
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
@wp.func
def apply_flux_vec2_2d(
    f: wp.vec2,
    flux_f_xl_dx: wp.vec2,
    flux_f_xr_dx: wp.vec2,
    flux_f_yl_dy: wp.vec2,
    flux_f_yr_dy: wp.vec2,
    dx: float,
    dt: float,
):  # pragma: no cover
    """Update a conserved vector cell value with the net flux through its faces.

    Vector counterpart of ``apply_flux_2d``: left-face fluxes are
    subtracted and right-face fluxes added, scaled by ``dx`` and ``dt``.

    Parameters
    ----------
    f : wp.vec2
        Cell vector value
    flux_f_xl_dx : wp.vec2
        Vector flux in x left
    flux_f_xr_dx : wp.vec2
        Vector flux in x right
    flux_f_yl_dy : wp.vec2
        Vector flux in y left
    flux_f_yr_dy : wp.vec2
        Vector flux in y right
    dx : float
        Cell size
    dt : float
        Time step size

    Returns
    -------
    wp.vec2
        Vector cell value with added flux
    """
    # common face weight: time step times face length
    face_weight = dt * dx
    f -= face_weight * flux_f_xl_dx
    f += face_weight * flux_f_xr_dx
    f -= face_weight * flux_f_yl_dy
    f += face_weight * flux_f_yr_dy
    return f
|
| 144 |
+
|
| 145 |
+
|
| 146 |
+
@wp.func
def euler_flux_2d(
    rho_l: float,
    rho_r: float,
    vx_l: float,
    vx_r: float,
    vy_l: float,
    vy_r: float,
    p_l: float,
    p_r: float,
    gamma: float,
):  # pragma: no cover
    """Compute the Euler flux across a face from left/right primitive states.

    The face flux is evaluated from the arithmetic average of the two
    conserved states, then stabilized with a diffusion term scaled by the
    maximum acoustic wavespeed of the two sides (Rusanov / local
    Lax-Friedrichs style). The x direction is treated as the face normal;
    callers handle the y direction by swapping velocity components.

    Parameters
    ----------
    rho_l : float
        Density left
    rho_r : float
        Density right
    vx_l : float
        X (face-normal) velocity left
    vx_r : float
        X (face-normal) velocity right
    vy_l : float
        Y (tangential) velocity left
    vy_r : float
        Y (tangential) velocity right
    p_l : float
        Pressure left
    p_r : float
        Pressure right
    gamma : float
        Adiabatic index (ratio of specific heats)

    Returns
    -------
    wp.vec4
        Vector containing mass, momentum x, momentum y, and energy flux.
    """
    # total energy per unit volume: internal p/(gamma-1) plus kinetic
    e_l = p_l / (gamma - 1.0) + 0.5 * rho_l * (vx_l * vx_l + vy_l * vy_l)
    e_r = p_r / (gamma - 1.0) + 0.5 * rho_r * (vx_r * vx_r + vy_r * vy_r)

    # arithmetic average of the conserved states on either side of the face
    rho_ave = 0.5 * (rho_l + rho_r)
    momx_ave = 0.5 * (rho_l * vx_l + rho_r * vx_r)
    momy_ave = 0.5 * (rho_l * vy_l + rho_r * vy_r)
    e_ave = 0.5 * (e_l + e_r)
    # pressure recovered from the averaged conserved state
    p_ave = (gamma - 1.0) * (
        e_ave - 0.5 * (momx_ave * momx_ave + momy_ave * momy_ave) / rho_ave
    )

    # central fluxes in the face-normal (x) direction
    flux_mass = momx_ave
    flux_momx = momx_ave * momx_ave / rho_ave + p_ave
    flux_momy = momx_ave * momy_ave / rho_ave
    flux_e = (e_ave + p_ave) * momx_ave / rho_ave

    # maximum signal speed: sound speed plus normal advection speed
    c_l = wp.sqrt(gamma * p_l / rho_l) + wp.abs(vx_l)
    c_r = wp.sqrt(gamma * p_r / rho_r) + wp.abs(vx_r)
    c = wp.max(c_l, c_r)

    # add stabilizing diffusion term proportional to the state jump
    flux_mass -= c * 0.5 * (rho_l - rho_r)
    flux_momx -= c * 0.5 * (rho_l * vx_l - rho_r * vx_r)
    flux_momy -= c * 0.5 * (rho_l * vy_l - rho_r * vy_r)
    flux_e -= c * 0.5 * (e_l - e_r)

    return wp.vec4(flux_mass, flux_momx, flux_momy, flux_e)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
@wp.kernel
def euler_primitive_to_conserved_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    gamma: float,
    vol: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Convert primitive Euler variables (rho, vel, p) to conserved ones.

    Reads the primitive state at each (batch, i, j) cell and writes the
    cell-integrated mass, momentum, and total energy.

    Parameters
    ----------
    rho : wp.array3d
        Density (input)
    vel : wp.array3d
        Velocity (input)
    p : wp.array3d
        Pressure (input)
    mass : wp.array3d
        Mass (output)
    mom : wp.array3d
        Momentum (output)
    e : wp.array3d
        Energy (output)
    gamma : float
        Adiabatic index (ratio of specific heats)
    vol : float
        Volume of cell
    lx : int
        Grid size x dim
    ly : int
        Grid size y dim
    """

    # get index
    b, i, j = wp.tid()

    # read primitive values (periodic wrap at the grid edges)
    rho_i_j = index_periodic_edges_batched_2d(rho, b, i, j, lx, ly)
    vel_i_j = index_vec2_periodic_edges_batched_2d(vel, b, i, j, lx, ly)
    p_i_j = index_periodic_edges_batched_2d(p, b, i, j, lx, ly)

    # compute conserved (cell-integrated) values
    mass_i_j = rho_i_j * vol
    mom_i_j = vel_i_j * rho_i_j * vol
    # total energy = internal energy p/(gamma-1) plus kinetic, times cell volume
    e_i_j = (
        p_i_j / (gamma - 1.0)
        + 0.5 * rho_i_j * (vel_i_j[0] * vel_i_j[0] + vel_i_j[1] * vel_i_j[1])
    ) * vol

    # set values
    mass[b, i, j] = mass_i_j
    mom[b, i, j] = mom_i_j
    e[b, i, j] = e_i_j
|
| 278 |
+
|
| 279 |
+
|
| 280 |
+
@wp.kernel
def euler_conserved_to_primitive_batched_2d(
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    gamma: float,
    vol: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Convert conserved Euler variables (mass, mom, e) to primitive ones.

    Inverse of ``euler_primitive_to_conserved_batched_2d``: reads the
    cell-integrated conserved state at each (batch, i, j) cell and writes
    density, velocity, and pressure.

    Parameters
    ----------
    mass : wp.array3d
        Mass (input)
    mom : wp.array3d
        Momentum (input)
    e : wp.array3d
        Energy (input)
    rho : wp.array3d
        Density (output)
    vel : wp.array3d
        Velocity (output)
    p : wp.array3d
        Pressure (output)
    gamma : float
        Adiabatic index (ratio of specific heats)
    vol : float
        Cell volume
    lx : int
        Grid size X dim
    ly : int
        Grid size Y dim
    """

    # get index
    b, i, j = wp.tid()

    # read conserved values (periodic wrap at the grid edges)
    mass_i_j = index_periodic_edges_batched_2d(mass, b, i, j, lx, ly)
    mom_i_j = index_vec2_periodic_edges_batched_2d(mom, b, i, j, lx, ly)
    e_i_j = index_periodic_edges_batched_2d(e, b, i, j, lx, ly)

    # recover primitive values
    rho_i_j = mass_i_j / vol
    # vel = mom / (rho * vol) = mom / mass
    vel_i_j = mom_i_j / rho_i_j / vol
    # pressure from total energy density minus kinetic energy density
    p_i_j = (
        e_i_j / vol
        - 0.5 * rho_i_j * (vel_i_j[0] * vel_i_j[0] + vel_i_j[1] * vel_i_j[1])
    ) * (gamma - 1.0)

    # set values
    rho[b, i, j] = rho_i_j
    vel[b, i, j] = vel_i_j
    p[b, i, j] = p_i_j
|
| 339 |
+
|
| 340 |
+
|
| 341 |
+
@wp.kernel
def euler_extrapolation_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    rho_xl: wp.array3d(dtype=float),
    rho_xr: wp.array3d(dtype=float),
    rho_yl: wp.array3d(dtype=float),
    rho_yr: wp.array3d(dtype=float),
    vel_xl: wp.array3d(dtype=wp.vec2),
    vel_xr: wp.array3d(dtype=wp.vec2),
    vel_yl: wp.array3d(dtype=wp.vec2),
    vel_yr: wp.array3d(dtype=wp.vec2),
    p_xl: wp.array3d(dtype=float),
    p_xr: wp.array3d(dtype=float),
    p_yl: wp.array3d(dtype=float),
    p_yr: wp.array3d(dtype=float),
    gamma: float,
    dx: float,
    dt: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Extrapolate Euler primitive values to cell face centers.

    Per (batch, i, j) cell: central-difference gradients are computed from
    a 5-point periodic stencil, the state is advanced half a time step
    using the primitive-form Euler equations, and the half-step state is
    extrapolated in space to the four face centers (xl/xr/yl/yr outputs).

    Parameters
    ----------
    rho : wp.array3d
        Density (input)
    vel : wp.array3d
        Velocity (input)
    p : wp.array3d
        Pressure (input)
    rho_xl : wp.array3d
        Density x left (output)
    rho_xr : wp.array3d
        Density x right (output)
    rho_yl : wp.array3d
        Density y left (output)
    rho_yr : wp.array3d
        Density y right (output)
    vel_xl : wp.array3d
        Velocity x left (output)
    vel_xr : wp.array3d
        Velocity x right (output)
    vel_yl : wp.array3d
        Velocity y left (output)
    vel_yr : wp.array3d
        Velocity y right (output)
    p_xl : wp.array3d
        Pressure x left (output)
    p_xr : wp.array3d
        Pressure x right (output)
    p_yl : wp.array3d
        Pressure y left (output)
    p_yr : wp.array3d
        Pressure y right (output)
    gamma : float
        Adiabatic index (ratio of specific heats)
    dx : float
        Cell size (same spacing is used for both axes)
    dt : float
        Time step size
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # get rho stencil (center plus 4 periodic neighbors)
    rho_1_1 = index_periodic_edges_batched_2d(rho, b, i, j, lx, ly)
    rho_2_1 = index_periodic_edges_batched_2d(rho, b, i + 1, j, lx, ly)
    rho_1_2 = index_periodic_edges_batched_2d(rho, b, i, j + 1, lx, ly)
    rho_0_1 = index_periodic_edges_batched_2d(rho, b, i - 1, j, lx, ly)
    rho_1_0 = index_periodic_edges_batched_2d(rho, b, i, j - 1, lx, ly)

    # get velocity stencil
    vel_1_1 = index_vec2_periodic_edges_batched_2d(vel, b, i, j, lx, ly)
    vel_2_1 = index_vec2_periodic_edges_batched_2d(vel, b, i + 1, j, lx, ly)
    vel_1_2 = index_vec2_periodic_edges_batched_2d(vel, b, i, j + 1, lx, ly)
    vel_0_1 = index_vec2_periodic_edges_batched_2d(vel, b, i - 1, j, lx, ly)
    vel_1_0 = index_vec2_periodic_edges_batched_2d(vel, b, i, j - 1, lx, ly)

    # get pressure stencil
    p_1_1 = index_periodic_edges_batched_2d(p, b, i, j, lx, ly)
    p_2_1 = index_periodic_edges_batched_2d(p, b, i + 1, j, lx, ly)
    p_1_2 = index_periodic_edges_batched_2d(p, b, i, j + 1, lx, ly)
    p_0_1 = index_periodic_edges_batched_2d(p, b, i - 1, j, lx, ly)
    p_1_0 = index_periodic_edges_batched_2d(p, b, i, j - 1, lx, ly)

    # compute density gradient (central differences)
    rho_dx = (rho_2_1 - rho_0_1) / (2.0 * dx)
    rho_dy = (rho_1_2 - rho_1_0) / (2.0 * dx)

    # compute velocity gradient
    vel_dx = (vel_2_1 - vel_0_1) / (2.0 * dx)
    vel_dy = (vel_1_2 - vel_1_0) / (2.0 * dx)

    # compute pressure gradient
    p_dx = (p_2_1 - p_0_1) / (2.0 * dx)
    p_dy = (p_1_2 - p_1_0) / (2.0 * dx)

    # advance density half a time step via the continuity equation
    rho_prime = rho_1_1 - 0.5 * dt * (
        vel_1_1[0] * rho_dx
        + rho_1_1 * vel_dx[0]
        + vel_1_1[1] * rho_dy
        + rho_1_1 * vel_dy[1]
    )
    # advance velocity half a time step via the momentum equation
    vx_prime = vel_1_1[0] - 0.5 * dt * (
        vel_1_1[0] * vel_dx[0] + vel_1_1[1] * vel_dy[0] + (1.0 / rho_1_1) * p_dx
    )
    vy_prime = vel_1_1[1] - 0.5 * dt * (
        vel_1_1[0] * vel_dx[1] + vel_1_1[1] * vel_dy[1] + (1.0 / rho_1_1) * p_dy
    )
    # advance pressure half a time step via the energy equation
    p_prime = p_1_1 - 0.5 * dt * (
        gamma * p_1_1 * (vel_dx[0] + vel_dy[1]) + vel_1_1[0] * p_dx + vel_1_1[1] * p_dy
    )

    # extrapolate the half-step state in space to the four face centers
    rho_space_extra = extrapolate_to_face_2d(rho_prime, rho_dx, rho_dy, dx)
    vx_space_extra = extrapolate_to_face_2d(vx_prime, vel_dx[0], vel_dy[0], dx)
    vy_space_extra = extrapolate_to_face_2d(vy_prime, vel_dx[1], vel_dy[1], dx)
    p_space_extra = extrapolate_to_face_2d(p_prime, p_dx, p_dy, dx)

    # store values (vec4 layout: [xl, xr, yl, yr])
    rho_xl[b, i, j] = rho_space_extra[0]
    rho_xr[b, i, j] = rho_space_extra[1]
    rho_yl[b, i, j] = rho_space_extra[2]
    rho_yr[b, i, j] = rho_space_extra[3]
    vel_xl[b, i, j] = wp.vec2(vx_space_extra[0], vy_space_extra[0])
    vel_xr[b, i, j] = wp.vec2(vx_space_extra[1], vy_space_extra[1])
    vel_yl[b, i, j] = wp.vec2(vx_space_extra[2], vy_space_extra[2])
    vel_yr[b, i, j] = wp.vec2(vx_space_extra[3], vy_space_extra[3])
    p_xl[b, i, j] = p_space_extra[0]
    p_xr[b, i, j] = p_space_extra[1]
    p_yl[b, i, j] = p_space_extra[2]
    p_yr[b, i, j] = p_space_extra[3]
|
| 482 |
+
|
| 483 |
+
|
| 484 |
+
@wp.kernel
def euler_get_flux_batched_2d(
    rho_xl: wp.array3d(dtype=float),
    rho_xr: wp.array3d(dtype=float),
    rho_yl: wp.array3d(dtype=float),
    rho_yr: wp.array3d(dtype=float),
    vel_xl: wp.array3d(dtype=wp.vec2),
    vel_xr: wp.array3d(dtype=wp.vec2),
    vel_yl: wp.array3d(dtype=wp.vec2),
    vel_yr: wp.array3d(dtype=wp.vec2),
    p_xl: wp.array3d(dtype=float),
    p_xr: wp.array3d(dtype=float),
    p_yl: wp.array3d(dtype=float),
    p_yr: wp.array3d(dtype=float),
    mass_flux_x: wp.array3d(dtype=float),
    mass_flux_y: wp.array3d(dtype=float),
    mom_flux_x: wp.array3d(dtype=wp.vec2),
    mom_flux_y: wp.array3d(dtype=wp.vec2),
    e_flux_x: wp.array3d(dtype=float),
    e_flux_y: wp.array3d(dtype=float),
    gamma: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Use extrapolated Euler face values to compute face fluxes.

    The face between cells i and i+1 pairs the right-extrapolated state of
    cell i with the left-extrapolated state of cell i+1 (same pattern in
    y). The y-direction flux reuses ``euler_flux_2d`` — which treats x as
    the face normal — by swapping the velocity components on input and
    swapping the momentum flux components back on output.

    Parameters
    ----------
    rho_xl : wp.array3d
        Density x left (input)
    rho_xr : wp.array3d
        Density x right (input)
    rho_yl : wp.array3d
        Density y left (input)
    rho_yr : wp.array3d
        Density y right (input)
    vel_xl : wp.array3d
        Velocity x left (input)
    vel_xr : wp.array3d
        Velocity x right (input)
    vel_yl : wp.array3d
        Velocity y left (input)
    vel_yr : wp.array3d
        Velocity y right (input)
    p_xl : wp.array3d
        Pressure x left (input)
    p_xr : wp.array3d
        Pressure x right (input)
    p_yl : wp.array3d
        Pressure y left (input)
    p_yr : wp.array3d
        Pressure y right (input)
    mass_flux_x : wp.array3d
        Mass flux x (output)
    mass_flux_y : wp.array3d
        Mass flux y (output)
    mom_flux_x : wp.array3d
        Momentum flux x (output)
    mom_flux_y : wp.array3d
        Momentum flux y (output)
    e_flux_x : wp.array3d
        Energy flux x (output)
    e_flux_y : wp.array3d
        Energy flux y (output)
    gamma : float
        Adiabatic index (ratio of specific heats)
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # gather the two states meeting at each face: the neighbor's left
    # extrapolation (at i+1 / j+1) and this cell's right extrapolation
    rho_xl_1 = index_periodic_edges_batched_2d(rho_xl, b, i + 1, j, lx, ly)
    rho_xr_0 = index_periodic_edges_batched_2d(rho_xr, b, i, j, lx, ly)
    rho_yl_1 = index_periodic_edges_batched_2d(rho_yl, b, i, j + 1, lx, ly)
    rho_yr_0 = index_periodic_edges_batched_2d(rho_yr, b, i, j, lx, ly)
    vel_xl_1 = index_vec2_periodic_edges_batched_2d(vel_xl, b, i + 1, j, lx, ly)
    vel_xr_0 = index_vec2_periodic_edges_batched_2d(vel_xr, b, i, j, lx, ly)
    vel_yl_1 = index_vec2_periodic_edges_batched_2d(vel_yl, b, i, j + 1, lx, ly)
    vel_yr_0 = index_vec2_periodic_edges_batched_2d(vel_yr, b, i, j, lx, ly)
    p_xl_1 = index_periodic_edges_batched_2d(p_xl, b, i + 1, j, lx, ly)
    p_xr_0 = index_periodic_edges_batched_2d(p_xr, b, i, j, lx, ly)
    p_yl_1 = index_periodic_edges_batched_2d(p_yl, b, i, j + 1, lx, ly)
    p_yr_0 = index_periodic_edges_batched_2d(p_yr, b, i, j, lx, ly)

    # x-face flux: x is the face normal, components passed as-is
    flux_x = euler_flux_2d(
        rho_xl_1,
        rho_xr_0,
        vel_xl_1[0],
        vel_xr_0[0],
        vel_xl_1[1],
        vel_xr_0[1],
        p_xl_1,
        p_xr_0,
        gamma,
    )
    # y-face flux: y is the face normal, so vx/vy are swapped on input
    flux_y = euler_flux_2d(
        rho_yl_1,
        rho_yr_0,
        vel_yl_1[1],
        vel_yr_0[1],
        vel_yl_1[0],
        vel_yr_0[0],
        p_yl_1,
        p_yr_0,
        gamma,
    )

    # set values (flux vec4 layout: [mass, normal mom, tangential mom, energy])
    mass_flux_x[b, i, j] = flux_x[0]
    mass_flux_y[b, i, j] = flux_y[0]
    mom_flux_x[b, i, j] = wp.vec2(flux_x[1], flux_x[2])
    # swap components back so stored order is (x momentum, y momentum)
    mom_flux_y[b, i, j] = wp.vec2(flux_y[2], flux_y[1])
    e_flux_x[b, i, j] = flux_x[3]
    e_flux_y[b, i, j] = flux_y[3]
|
| 604 |
+
|
| 605 |
+
|
| 606 |
+
@wp.kernel
def euler_apply_flux_batched_2d(
    mass_flux_x: wp.array3d(dtype=float),
    mass_flux_y: wp.array3d(dtype=float),
    mom_flux_x: wp.array3d(dtype=wp.vec2),
    mom_flux_y: wp.array3d(dtype=wp.vec2),
    e_flux_x: wp.array3d(dtype=float),
    e_flux_y: wp.array3d(dtype=float),
    mass: wp.array3d(dtype=float),
    mom: wp.array3d(dtype=wp.vec2),
    e: wp.array3d(dtype=float),
    dx: float,
    dt: float,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Apply face fluxes to the conserved Euler values (in place).

    For each (batch, i, j) cell the flux stored at the cell itself and at
    its i-1 / j-1 neighbor (the two faces bounding the cell in each axis)
    are combined via ``apply_flux_2d`` / ``apply_flux_vec2_2d`` and the
    updated conserved values are written back into ``mass``, ``mom``, ``e``.
    Each thread writes only its own cell; the flux arrays are read-only.

    Parameters
    ----------
    mass_flux_x : wp.array3d
        Mass flux X (input)
    mass_flux_y : wp.array3d
        Mass flux Y (input)
    mom_flux_x : wp.array3d
        Momentum flux X (input)
    mom_flux_y : wp.array3d
        Momentum flux Y (input)
    e_flux_x : wp.array3d
        Energy flux X (input)
    e_flux_y : wp.array3d
        Energy flux Y (input)
    mass : wp.array3d
        Mass (updated in place)
    mom : wp.array3d
        Momentum (updated in place)
    e : wp.array3d
        Energy (updated in place)
    dx : float
        Cell size
    dt : float
        Time step size
    lx : int
        Grid size x
    ly : int
        Grid size y
    """

    # get index
    b, i, j = wp.tid()

    # get new mass: fluxes on the cell's own faces (i, j) and the
    # neighboring faces (i-1, j-1), periodic at the edges
    mass_1 = index_periodic_edges_batched_2d(mass, b, i, j, lx, ly)
    mass_flux_x_1 = index_periodic_edges_batched_2d(mass_flux_x, b, i, j, lx, ly)
    mass_flux_x_0 = index_periodic_edges_batched_2d(mass_flux_x, b, i - 1, j, lx, ly)
    mass_flux_y_1 = index_periodic_edges_batched_2d(mass_flux_y, b, i, j, lx, ly)
    mass_flux_y_0 = index_periodic_edges_batched_2d(mass_flux_y, b, i, j - 1, lx, ly)
    new_mass = apply_flux_2d(
        mass_1, mass_flux_x_1, mass_flux_x_0, mass_flux_y_1, mass_flux_y_0, dx, dt
    )

    # get new momentum (same face pattern, vector-valued)
    mom_1 = index_vec2_periodic_edges_batched_2d(mom, b, i, j, lx, ly)
    mom_flux_x_1 = index_vec2_periodic_edges_batched_2d(mom_flux_x, b, i, j, lx, ly)
    mom_flux_x_0 = index_vec2_periodic_edges_batched_2d(mom_flux_x, b, i - 1, j, lx, ly)
    mom_flux_y_1 = index_vec2_periodic_edges_batched_2d(mom_flux_y, b, i, j, lx, ly)
    mom_flux_y_0 = index_vec2_periodic_edges_batched_2d(mom_flux_y, b, i, j - 1, lx, ly)
    new_mom = apply_flux_vec2_2d(
        mom_1, mom_flux_x_1, mom_flux_x_0, mom_flux_y_1, mom_flux_y_0, dx, dt
    )

    # get new energy
    e_1 = index_periodic_edges_batched_2d(e, b, i, j, lx, ly)
    e_flux_x_1 = index_periodic_edges_batched_2d(e_flux_x, b, i, j, lx, ly)
    e_flux_x_0 = index_periodic_edges_batched_2d(e_flux_x, b, i - 1, j, lx, ly)
    e_flux_y_1 = index_periodic_edges_batched_2d(e_flux_y, b, i, j, lx, ly)
    e_flux_y_0 = index_periodic_edges_batched_2d(e_flux_y, b, i, j - 1, lx, ly)
    new_e = apply_flux_2d(e_1, e_flux_x_1, e_flux_x_0, e_flux_y_1, e_flux_y_0, dx, dt)

    # set values
    mass[b, i, j] = new_mass
    mom[b, i, j] = new_mom
    e[b, i, j] = new_e
|
| 689 |
+
|
| 690 |
+
|
| 691 |
+
@wp.kernel
def initialize_kelvin_helmoltz_batched_2d(
    rho: wp.array3d(dtype=float),
    vel: wp.array3d(dtype=wp.vec2),
    p: wp.array3d(dtype=float),
    w: wp.array2d(dtype=float),
    sigma: float,
    lx: float,
    ly: float,
    nr_freq: int,
):  # pragma: no cover
    """Initialize the state for a Kelvin-Helmholtz instability.

    A dense band (rho=2, ux=+0.5) occupies the middle half of the domain
    in y, with a lighter counter-flowing fluid (rho=1, ux=-0.5) outside.
    A multi-frequency sinusoidal y-velocity perturbation, Gaussian-localized
    around the two shear layers (y=0.25 and y=0.75), seeds the instability.
    Pressure is uniform at 2.5.

    Parameters
    ----------
    rho : wp.array3d
        Density (output)
    vel : wp.array3d
        Velocity (output)
    p : wp.array3d
        Pressure (output)
    w : wp.array2d
        Per-batch perturbation frequency amplitudes, indexed [batch, freq]
    sigma : float
        Gaussian width of the perturbation around each shear layer
    lx : float
        Grid size x
    ly : float
        Grid size y
    nr_freq : int
        Number of frequencies in the perturbation
    """

    # get normalized cell coordinates in [0, 1)
    b, i, j = wp.tid()
    x = wp.float(i) / wp.float(lx)
    y = wp.float(j) / wp.float(ly)

    # initial flow bands: middle half of the domain flows right and is denser
    if wp.abs(y - 0.5) < 0.25:
        ux = 0.5
        r = 2.0
    else:
        ux = -0.5
        r = 1.0

    # perturbation: sum of sinusoids in x, Gaussian-localized at the two
    # shear layers (pi approximated as 3.14159)
    uy = wp.float32(0.0)
    for f in range(nr_freq):
        ff = wp.float32(f + 1)
        uy += (
            ff
            * w[b, f]
            * wp.sin(4.0 * 3.14159 * x * ff)
            * (
                wp.exp(-(y - 0.25) * (y - 0.25) / (2.0 * sigma * sigma))
                + wp.exp(-(y - 0.75) * (y - 0.75) / (2.0 * sigma * sigma))
            )
        )
    u = wp.vec2(ux, uy)

    # set values (uniform pressure)
    rho[b, i, j] = r
    vel[b, i, j] = u
    p[b, i, j] = 2.5
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/indexing.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import warp as wp
|
| 19 |
+
except ImportError:
|
| 20 |
+
print(
|
| 21 |
+
"""NVIDIA WARP is required for this datapipe. This package is under the
|
| 22 |
+
NVIDIA Source Code License (NVSCL). To install use:
|
| 23 |
+
|
| 24 |
+
pip install warp-lang
|
| 25 |
+
"""
|
| 26 |
+
)
|
| 27 |
+
raise SystemExit(1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
# TODO bug in warp mod function
@wp.func
def _mod_int(x: int, length: int):  # pragma: no cover
    """Wrap an index into the periodic range [0, length).

    Workaround for a bug in warp's mod function (see TODO above).

    NOTE(review): this only corrects an index that is at most one period
    out of range (i.e. -length <= x <= 2 * length - 2); it is not a
    general modulo. Callers pass stencil offsets of +/-1 or +/-2, which
    stay within that range.

    Parameters
    ----------
    x : int
        Int to mod
    length : int
        Mod by value

    Returns
    -------
    int
        Mod of x
    """
    if x < 0:
        return x + length
    elif x > length - 1:
        return x - length
    return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@wp.func
def index_zero_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index batched 2d array with zero on edges.

    Accesses one cell beyond the grid (x == -1, x == lx, y == -1 or
    y == ly) return 0.0 instead of reading memory, emulating a zero
    (Dirichlet-style) halo for stencil operations.

    Parameters
    ----------
    array : wp.array3d
        Array to index, shaped [batch, lx, ly]
    b : int
        Batch index
    x : int
        X index (may be -1 or lx for the halo)
    y : int
        Y index (may be -1 or ly for the halo)
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value, or 0.0 on the halo edge
    """
    if x == -1:
        return 0.0
    elif x == lx:
        return 0.0
    elif y == -1:
        return 0.0
    elif y == ly:
        return 0.0
    else:
        return array[b, x, y]
|
| 90 |
+
|
| 91 |
+
|
| 92 |
+
@wp.func
def index_clamped_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index batched 2d array with edges clamped to same value.

    Out-of-range indices are clamped into [0, lx-1] x [0, ly-1], so halo
    reads repeat the nearest edge value (Neumann-style boundary).

    Parameters
    ----------
    array : wp.array3d
        Array to index, shaped [batch, lx, ly]
    b : int
        Batch index
    x : int
        X index (clamped into range)
    y : int
        Y index (clamped into range)
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value at the clamped position
    """
    x = wp.clamp(x, 0, lx - 1)
    y = wp.clamp(y, 0, ly - 1)
    return array[b, x, y]
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
@wp.func
def index_periodic_edges_batched_2d(
    array: wp.array3d(dtype=float), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index batched 2d array with periodic edges.

    Indices are wrapped via ``_mod_int``, which only handles indices at
    most one period out of range -- adequate for the small stencil
    offsets used by the finite-difference/volume kernels.

    Parameters
    ----------
    array : wp.array3d
        Array to index, shaped [batch, lx, ly]
    b : int
        Batch index
    x : int
        X index (wrapped periodically)
    y : int
        Y index (wrapped periodically)
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    float
        Array value at the wrapped position
    """
    x = _mod_int(x, lx)
    y = _mod_int(y, ly)
    return array[b, x, y]
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
@wp.func
def index_vec2_periodic_edges_batched_2d(
    vec: wp.array3d(dtype=wp.vec2), b: int, x: int, y: int, lx: int, ly: int
):  # pragma: no cover
    """Index batched 2d array of wp.vec2 with periodic edges.

    Vector-valued counterpart of ``index_periodic_edges_batched_2d``;
    indices are wrapped via ``_mod_int`` (single-period wrap only).

    Parameters
    ----------
    vec : wp.array3d
        Array to index, shaped [batch, lx, ly] with wp.vec2 elements
    b : int
        Batch index
    x : int
        X index (wrapped periodically)
    y : int
        Y index (wrapped periodically)
    lx : int
        Grid size x
    ly : int
        Grid size y

    Returns
    -------
    wp.vec2
        Vector value at the wrapped position
    """
    x = _mod_int(x, lx)
    y = _mod_int(y, ly)
    return vec[b, x, y]
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/initialization.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import warp as wp
|
| 19 |
+
except ImportError:
|
| 20 |
+
print(
|
| 21 |
+
"""NVIDIA WARP is required for this datapipe. This package is under the
|
| 22 |
+
NVIDIA Source Code License (NVSCL). To install use:
|
| 23 |
+
|
| 24 |
+
pip install warp-lang
|
| 25 |
+
"""
|
| 26 |
+
)
|
| 27 |
+
raise SystemExit(1)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@wp.kernel
def init_uniform_random_2d(
    array: wp.array2d(dtype=float),
    min_value: float,
    max_value: float,
    external_seed: int,
):  # pragma: no cover
    """Initialize 2d array with uniform random values in [min_value, max_value).

    Parameters
    ----------
    array : wp.array2d
        Array to initialize
    min_value : float
        Min random value
    max_value : float
        Max random value
    external_seed : int
        External seed to use
    """
    i, j = wp.tid()
    # NOTE(review): wp.tid() is used as the per-thread rand offset, as in
    # init_uniform_random_4d; confirm it yields a distinct RNG stream per
    # thread in this warp version.
    state = wp.rand_init(external_seed, wp.tid())
    # Fix: sample from [min_value, max_value) as documented. The previous
    # code negated min_value (wp.randf(state, -min_value, max_value)),
    # which contradicted the docstring and the sibling 4d kernel.
    array[i, j] = wp.randf(state, min_value, max_value)
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
@wp.kernel
def init_uniform_random_4d(
    array: wp.array4d(dtype=float),
    min_value: float,
    max_value: float,
    external_seed: int,
):  # pragma: no cover
    """Initialize 4d array with uniform random values in [min_value, max_value).

    Parameters
    ----------
    array : wp.array4d
        Array to initialize
    min_value : float
        Min random value
    max_value : float
        Max random value
    external_seed : int
        External seed to use
    """
    b, i, j, k = wp.tid()
    # NOTE(review): wp.tid() is used as the per-thread rand offset; confirm
    # it yields a distinct RNG stream per thread in this warp version.
    state = wp.rand_init(external_seed, wp.tid())
    array[b, i, j, k] = wp.randf(state, min_value, max_value)
|
physics_mcp/source/physicsnemo/datapipes/benchmarks/kernels/utils.py
ADDED
|
@@ -0,0 +1,141 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
try:
|
| 18 |
+
import warp as wp
|
| 19 |
+
except ImportError:
|
| 20 |
+
print(
|
| 21 |
+
"""NVIDIA WARP is required for this datapipe. This package is under the
|
| 22 |
+
NVIDIA Source Code License (NVSCL). To install use:
|
| 23 |
+
|
| 24 |
+
pip install warp-lang
|
| 25 |
+
"""
|
| 26 |
+
)
|
| 27 |
+
raise SystemExit(1)
|
| 28 |
+
|
| 29 |
+
from .indexing import index_zero_edges_batched_2d
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
@wp.kernel
def bilinear_upsample_batched_2d(
    array: wp.array3d(dtype=float), lx: int, ly: int, grid_reduction_factor: int
):  # pragma: no cover
    """Bilinear upsampling from batch 2d array.

    Fills fine-grid values in place by bilinearly interpolating between
    coarse anchor points spaced ``grid_reduction_factor`` apart.
    Out-of-range neighbor reads return 0.0 (zero-edge indexing).

    NOTE(review): anchor points (where (x + 1) % grid_reduction_factor == 0)
    have x_0 == x_1 == x and rel_x == 0, so they write back their own value;
    this makes the concurrent in-place read/write pattern appear benign for
    anchors, but confirm the launch grid is exactly [batch, lx, ly].

    Parameters
    ----------
    array : wp.array3d
        Array to perform upsampling on (modified in place)
    lx : int
        Grid size X
    ly : int
        Grid size Y
    grid_reduction_factor : int
        Grid reduction factor for multi-grid
    """
    # get index
    b, x, y = wp.tid()

    # get four neighbors coordinates
    x_0 = x - (x + 1) % grid_reduction_factor
    x_1 = x + (x + 1) % grid_reduction_factor
    y_0 = y - (y + 1) % grid_reduction_factor
    y_1 = y + (y + 1) % grid_reduction_factor

    # simple linear upsampling: fetch the four surrounding coarse values
    d_0_0 = index_zero_edges_batched_2d(array, b, x_0, y_0, lx, ly)
    d_1_0 = index_zero_edges_batched_2d(array, b, x_1, y_0, lx, ly)
    d_0_1 = index_zero_edges_batched_2d(array, b, x_0, y_1, lx, ly)
    d_1_1 = index_zero_edges_batched_2d(array, b, x_1, y_1, lx, ly)

    # get relative distance of this fine point from the lower anchor
    rel_x = wp.float32(x - x_0) / wp.float32(grid_reduction_factor)
    rel_y = wp.float32(y - y_0) / wp.float32(grid_reduction_factor)

    # interpolation in x direction
    d_x_0 = (1.0 - rel_x) * d_0_0 + rel_x * d_1_0
    d_x_1 = (1.0 - rel_x) * d_0_1 + rel_x * d_1_1

    # interpolation in y direction
    d = (1.0 - rel_y) * d_x_0 + rel_y * d_x_1

    # set interpolation
    array[b, x, y] = d
|
| 77 |
+
|
| 78 |
+
|
| 79 |
+
@wp.kernel
def threshold_3d(
    array: wp.array3d(dtype=float), threshold: float, min_value: float, max_value: float
):  # pragma: no cover
    """Threshold 3d array by value, in place.

    Values strictly below ``threshold`` become ``min_value``; values at or
    above it become ``max_value``.

    Parameters
    ----------
    array : wp.array3d
        Array to apply threshold on (modified in place)
    threshold : float
        Threshold value
    min_value : float
        Value to set if below threshold
    max_value : float
        Value to set if at or above threshold
    """
    i, j, k = wp.tid()
    if array[i, j, k] < threshold:
        array[i, j, k] = min_value
    else:
        array[i, j, k] = max_value
|
| 101 |
+
|
| 102 |
+
|
| 103 |
+
@wp.kernel
def fourier_to_array_batched_2d(
    array: wp.array3d(dtype=float),
    fourier: wp.array4d(dtype=float),
    nr_freq: int,
    lx: int,
    ly: int,
):  # pragma: no cover
    """Array of Fourier amplitudes to batched 2d spatial array.

    Evaluates a 2D Fourier series (all sin/cos products up to ``nr_freq``
    frequencies per axis) at each grid point and accumulates the normalized
    result into ``array`` with an atomic add.

    NOTE(review): contributions are *added* to the existing contents of
    ``array``; presumably the caller zero-initializes it first -- confirm.

    Parameters
    ----------
    array : wp.array3d
        Spatial array, shaped [batch, lx, ly]
    fourier : wp.array4d
        Array of Fourier amplitudes, indexed [term, batch, i, j]; term 0..3
        selects the sin*sin, cos*sin, sin*cos and cos*cos components
    nr_freq : int
        Number of frequencies in Fourier array
    lx : int
        Grid size x
    ly : int
        Grid size y
    """
    b, x, y = wp.tid()
    # 6.28318 ~= 2*pi: map grid indices onto one full period.
    dx = 6.28318 / wp.float32(lx)
    dy = 6.28318 / wp.float32(ly)
    rx = dx * wp.float32(x)
    ry = dy * wp.float32(y)
    for i in range(nr_freq):
        for j in range(nr_freq):
            ri = wp.float32(i)
            rj = wp.float32(j)
            ss = fourier[0, b, i, j] * wp.sin(ri * rx) * wp.sin(rj * ry)
            cs = fourier[1, b, i, j] * wp.cos(ri * rx) * wp.sin(rj * ry)
            sc = fourier[2, b, i, j] * wp.sin(ri * rx) * wp.cos(rj * ry)
            cc = fourier[3, b, i, j] * wp.cos(ri * rx) * wp.cos(rj * ry)
            # 1/nr_freq^2 normalization applied per-term inside the loop
            wp.atomic_add(
                array, b, x, y, 1.0 / (wp.float32(nr_freq) ** 2.0) * (ss + cs + sc + cc)
            )
|
physics_mcp/source/physicsnemo/datapipes/cae/__init__.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from .domino_datapipe import DoMINODataPipe
|
| 18 |
+
from .mesh_datapipe import MeshDatapipe
|
physics_mcp/source/physicsnemo/datapipes/cae/cae_dataset.py
ADDED
|
@@ -0,0 +1,1275 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import pathlib
|
| 18 |
+
import time
|
| 19 |
+
from abc import ABC, abstractmethod
|
| 20 |
+
from concurrent.futures import ThreadPoolExecutor
|
| 21 |
+
|
| 22 |
+
import numpy as np
|
| 23 |
+
import torch
|
| 24 |
+
import torch.distributed as dist
|
| 25 |
+
import zarr
|
| 26 |
+
from torch.distributed.tensor import Replicate, Shard
|
| 27 |
+
|
| 28 |
+
try:
|
| 29 |
+
import tensorstore as ts
|
| 30 |
+
|
| 31 |
+
TENSORSTORE_AVAILABLE = True
|
| 32 |
+
except ImportError:
|
| 33 |
+
TENSORSTORE_AVAILABLE = False
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import pyvista as pv
|
| 37 |
+
|
| 38 |
+
PV_AVAILABLE = True
|
| 39 |
+
except ImportError:
|
| 40 |
+
PV_AVAILABLE = False
|
| 41 |
+
|
| 42 |
+
from physicsnemo.distributed import ShardTensor, ShardTensorSpec
|
| 43 |
+
from physicsnemo.distributed.utils import compute_split_shapes
|
| 44 |
+
|
| 45 |
+
# Abstractions:
|
| 46 |
+
# - want to read npy/npz/.zarr/.stl/.vtp files
|
| 47 |
+
# - Need to share next level abstractions
|
| 48 |
+
# - Domain parallel dataloading is supported: output will be ShardTensor instead.
|
| 49 |
+
# - need to be able to configure preprocessing
|
| 50 |
+
# - CPU -> GPU transfer happens here, needs to be isolated in its own stream
|
| 51 |
+
# - Output of dataloader should be torch.Tensor objects.
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
"""
|
| 55 |
+
This datapipe handles reading files from Zarr and piping into torch.Tensor objects.
|
| 56 |
+
|
| 57 |
+
It's expected that the files are organized as groups, with each .zarr
|
| 58 |
+
file representing one training example. To improve IO performance, the files
|
| 59 |
+
should be chunked for each array. The reader takes a list of keys in the
|
| 60 |
+
group to read, and will not read keys that are not specified. The exception
|
| 61 |
+
is if _no_ keys are passed, in which case _all_ keys will be read.
|
| 62 |
+
"""
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class BackendReader(ABC):
|
| 66 |
+
"""
|
| 67 |
+
Abstract base class for backend readers.
|
| 68 |
+
"""
|
| 69 |
+
|
| 70 |
+
def __init__(
|
| 71 |
+
self,
|
| 72 |
+
keys_to_read: list[str] | None,
|
| 73 |
+
keys_to_read_if_available: dict[str, torch.Tensor] | None,
|
| 74 |
+
) -> None:
|
| 75 |
+
"""
|
| 76 |
+
Initialize the backend reader.
|
| 77 |
+
"""
|
| 78 |
+
self.keys_to_read = keys_to_read
|
| 79 |
+
self.keys_to_read_if_available = keys_to_read_if_available
|
| 80 |
+
|
| 81 |
+
self.volume_sampling_size = None
|
| 82 |
+
|
| 83 |
+
self.is_volumetric = any(["volume" in key for key in self.keys_to_read])
|
| 84 |
+
|
| 85 |
+
@abstractmethod
|
| 86 |
+
def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
|
| 87 |
+
"""
|
| 88 |
+
Read a file and return a dictionary of tensors.
|
| 89 |
+
"""
|
| 90 |
+
pass
|
| 91 |
+
|
| 92 |
+
@abstractmethod
|
| 93 |
+
def read_file_sharded(
|
| 94 |
+
self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
|
| 95 |
+
) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
|
| 96 |
+
"""
|
| 97 |
+
Read a file and return a dictionary of tensors ready to convert to ShardTensors.
|
| 98 |
+
|
| 99 |
+
NOTE: this function does not actually convert torch tensors to ShardTensors.
|
| 100 |
+
It's possible that the conversion, in some cases, can be a collective function.
|
| 101 |
+
Due to the async nature of the loader, we don't rely on any ordering of
|
| 102 |
+
collectives and defer them to the last possible minute.
|
| 103 |
+
|
| 104 |
+
Additionally, these functions return CPU tensors and we don't actually
|
| 105 |
+
define shard tensors on cpu.
|
| 106 |
+
|
| 107 |
+
So, the dataset itself will convert a local tensor + shard info to shard tensor
|
| 108 |
+
after the cpu-> gpu movement.
|
| 109 |
+
"""
|
| 110 |
+
pass
|
| 111 |
+
|
| 112 |
+
def fill_optional_keys(
|
| 113 |
+
self, data: dict[str, torch.Tensor]
|
| 114 |
+
) -> dict[str, torch.Tensor]:
|
| 115 |
+
"""
|
| 116 |
+
Fill missing keys with the keys from the keys_to_read_if_available dictionary.
|
| 117 |
+
"""
|
| 118 |
+
for key in self.keys_to_read_if_available:
|
| 119 |
+
if key not in data.keys():
|
| 120 |
+
data[key] = self.keys_to_read_if_available[key]
|
| 121 |
+
return data
|
| 122 |
+
|
| 123 |
+
def _get_slice_boundaries(
|
| 124 |
+
self, array_shape: tuple[int], this_rank: int, n_splits: int, split_dim: int = 0
|
| 125 |
+
) -> tuple[int, int, tuple | None]:
|
| 126 |
+
"""
|
| 127 |
+
For an array, determine the slice boundaries for parallel reading.
|
| 128 |
+
|
| 129 |
+
Args:
|
| 130 |
+
array_shape: The total shape of the target array.
|
| 131 |
+
this_rank: The rank of the distributed process.
|
| 132 |
+
n_splits: The size of the distributed process.
|
| 133 |
+
split_dim: The dimension to split, default is 0.
|
| 134 |
+
|
| 135 |
+
Returns:
|
| 136 |
+
The slice boundaries for parallel reading.
|
| 137 |
+
"""
|
| 138 |
+
# Determine what slice this rank should read
|
| 139 |
+
|
| 140 |
+
sections = compute_split_shapes(array_shape[split_dim], n_splits)
|
| 141 |
+
|
| 142 |
+
global_chunk_start = sum(sections[:this_rank])
|
| 143 |
+
global_chunk_stop = global_chunk_start + sections[this_rank]
|
| 144 |
+
|
| 145 |
+
chunk_sizes = tuple(
|
| 146 |
+
array_shape[:split_dim] + (section,) + array_shape[split_dim + 1 :]
|
| 147 |
+
for section in sections
|
| 148 |
+
)
|
| 149 |
+
|
| 150 |
+
return global_chunk_start, global_chunk_stop, chunk_sizes
|
| 151 |
+
|
| 152 |
+
def set_volume_sampling_size(self, volume_sampling_size: int):
|
| 153 |
+
"""
|
| 154 |
+
Set the volume sampling size. When set, the readers will
|
| 155 |
+
assume the volumetric data is shuffled on disk and read only
|
| 156 |
+
contiguous chunks of the data up to the sampling size.
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
Args:
|
| 160 |
+
volume_sampling_size: The total size of the volume sampling.
|
| 161 |
+
|
| 162 |
+
"""
|
| 163 |
+
self.volume_sampling_size = volume_sampling_size
|
| 164 |
+
|
| 165 |
+
def select_random_sections_from_slice(
|
| 166 |
+
self,
|
| 167 |
+
slice_start: int,
|
| 168 |
+
slice_stop: int,
|
| 169 |
+
n_points: int,
|
| 170 |
+
) -> slice:
|
| 171 |
+
"""
|
| 172 |
+
|
| 173 |
+
select the contiguous chunks of the volume data to read.
|
| 174 |
+
|
| 175 |
+
Args:
|
| 176 |
+
n_volume_points: The number of points to sample from the volume.
|
| 177 |
+
|
| 178 |
+
Returns:
|
| 179 |
+
A tuple of the start and stop indices of the contiguous chunks.
|
| 180 |
+
"""
|
| 181 |
+
|
| 182 |
+
if slice_stop - slice_start < n_points:
|
| 183 |
+
raise ValueError(
|
| 184 |
+
f"Slice size {slice_stop - slice_start} is less than the number of points {n_points}"
|
| 185 |
+
)
|
| 186 |
+
|
| 187 |
+
# Choose a random start point that will fit the entire n_points region:
|
| 188 |
+
start = np.random.randint(slice_start, slice_stop - n_points)
|
| 189 |
+
return slice(start, start + n_points)
|
| 190 |
+
|
| 191 |
+
|
| 192 |
+
class NpyFileReader(BackendReader):
    """
    Reader for numpy files.

    Expects each ``.npy`` file to contain a pickled dict mapping key names
    to numpy arrays (saved via ``np.save`` of a dict).
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read a file and return a dictionary of tensors.

        Raises:
            ValueError: If any of ``keys_to_read`` is absent from the file.
        """
        # allow_pickle is required because the file stores a dict, not a
        # bare array. NOTE(review): only load trusted files this way --
        # unpickling untrusted data can execute arbitrary code.
        data = np.load(filename, allow_pickle=True).item()

        missing_keys = set(self.keys_to_read) - set(data.keys())

        if len(missing_keys) > 0:
            # Fix: report the actual file path instead of the literal
            # placeholder "(unknown)".
            raise ValueError(f"Keys {missing_keys} not found in file {filename}")

        data = {key: torch.from_numpy(data[key]) for key in self.keys_to_read}

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
    ) -> dict[str, ShardTensor]:
        # Sharded reading is not implemented for npy files.
        pass

    def set_volume_sampling_size(self, volume_sampling_size: int):
        """
        This is not supported for npy files.
        """
        raise NotImplementedError(
            "volume sampling directly from disk is not supported for npy files."
        )
|
| 231 |
+
|
| 232 |
+
|
| 233 |
+
class NpzFileReader(BackendReader):
    """
    Reader for ``.npz`` files, one stored array per requested key.
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read a file and return a dictionary of tensors.

        Args:
            filename: Path to the ``.npz`` file.

        Returns:
            Mapping of each requested key to a torch tensor, with optional
            default keys filled in by ``fill_optional_keys``.

        Raises:
            ValueError: If any requested key is absent from the file.
        """
        # Use a context manager: np.load on an .npz keeps the underlying zip
        # handle open until closed, which previously leaked on error paths.
        with np.load(filename) as in_data:
            keys_found = set(in_data.keys())
            keys_missing = set(self.keys_to_read) - keys_found
            if len(keys_missing) > 0:
                raise ValueError(f"Keys {keys_missing} not found in file {filename}")

            # Make sure to select the slice outside of the loop.
            # Default to "read everything" so a key containing "volume" can
            # never hit an undefined name when is_volumetric is False.
            volume_slice = slice(None)
            if self.is_volumetric:
                if self.volume_sampling_size is not None:
                    volume_slice = self.select_random_sections_from_slice(
                        0,
                        in_data["volume_mesh_centers"].shape[0],
                        self.volume_sampling_size,
                    )
                else:
                    volume_slice = slice(0, in_data["volume_mesh_centers"].shape[0])

            # This is a slower basic way to do this, to be improved:
            data = {}
            for key in self.keys_to_read:
                if "volume" not in key:
                    data[key] = torch.from_numpy(in_data[key][:])
                else:
                    data[key] = torch.from_numpy(in_data[key][volume_slice])

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
    ) -> dict[str, ShardTensor]:
        """
        Sharded reading is not implemented for ``.npz`` files.

        Raises:
            NotImplementedError: Always. (Previously this silently returned
            ``None``, which broke callers expecting a dict.)
        """
        raise NotImplementedError(
            "Sharded reading is not supported for npz files."
        )

    def set_volume_sampling_size(self, volume_sampling_size: int):
        """
        This is not supported for npz files.
        """
        raise NotImplementedError(
            "volume sampling directly from disk is not supported for npz files."
        )
|
| 291 |
+
|
| 292 |
+
|
| 293 |
+
class ZarrFileReader(BackendReader):
    """
    Reader for zarr files (a zarr group per sample, one array per key).

    Supports both whole-file reads and sharded reads across a 1-D device
    mesh for domain-parallel training.
    """

    def __init__(
        self,
        keys_to_read: list[str] | None,
        keys_to_read_if_available: dict[str, torch.Tensor] | None,
    ) -> None:
        super().__init__(keys_to_read, keys_to_read_if_available)

    def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
        """
        Read a file and return a dictionary of tensors.

        Keys whose name contains "volume" are read through ``volume_slice``
        (a random contiguous chunk when volume sampling is enabled); all
        other keys are read in full.

        Raises:
            ValueError: If any requested key is absent from the group.
        """
        group = zarr.open_group(filename, mode="r")

        missing_keys = set(self.keys_to_read) - set(group.keys())

        if len(missing_keys) > 0:
            raise ValueError(f"Keys {missing_keys} not found in file (unknown)")

        # Make sure to select the slice outside of the loop, so every
        # "volume" key uses the same contiguous region.
        if self.is_volumetric:
            if self.volume_sampling_size is not None:
                volume_slice = self.select_random_sections_from_slice(
                    0,
                    group["volume_mesh_centers"].shape[0],
                    self.volume_sampling_size,
                )
            else:
                volume_slice = slice(0, group["volume_mesh_centers"].shape[0])

        # This is a slower basic way to do this, to be improved:
        data = {}
        for key in self.keys_to_read:
            if "volume" not in key:
                data[key] = torch.from_numpy(group[key][:])
            else:
                # NOTE(review): volume_slice is only defined when
                # is_volumetric is True - presumably "volume" keys only
                # appear in that case; verify against BackendReader.
                data[key] = torch.from_numpy(group[key][volume_slice])

        return self.fill_optional_keys(data)

    def read_file_sharded(
        self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
    ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
        """
        Read this rank's shard of a file.

        Arrays are split along dim 0 across the mesh (Shard(0)); scalars and
        arrays smaller than the mesh are read fully on every rank and marked
        Replicate. Returns (tensors, specs) where each spec is a
        (placements, chunk_sizes) pair used later to build ShardTensors.

        Raises:
            ValueError: If any requested key is absent from the group.
        """

        # We need the coordinates of this GPU within the (1-D) mesh:
        this_rank = device_mesh.get_local_rank()
        domain_size = dist.get_world_size(group=device_mesh.get_group())

        group = zarr.open_group(filename, mode="r")

        missing_keys = set(self.keys_to_read) - set(group.keys())

        if len(missing_keys) > 0:
            raise ValueError(f"Keys {missing_keys} not found in file (unknown)")

        data = {}
        specs = {}
        for key in self.keys_to_read:
            # Open the array in zarr without reading it and get info:
            zarr_array = group[key]
            array_shape = zarr_array.shape
            if array_shape == ():
                # Read scalars from every rank and use replicate sharding.
                # NOTE(review): a 0-d zarr array is indexed with [:] here;
                # the documented scalar read is zarr_array[()] - confirm.
                raw_data = torch.from_numpy(zarr_array[:])
                placement = [
                    Replicate(),
                ]
                chunk_sizes = None
            else:
                target_dim = 0
                if array_shape[target_dim] < domain_size:
                    # If the array is smaller than the number of ranks,
                    # again read and use replicate sharding:
                    raw_data = torch.from_numpy(zarr_array[:])
                    placement = [
                        Replicate(),
                    ]
                    chunk_sizes = None
                else:
                    # Read partially from the data and use Shard(target_dim)
                    # sharding; _get_slice_boundaries yields this rank's
                    # [start, stop) plus every rank's chunk sizes.
                    chunk_start, chunk_stop, chunk_sizes = self._get_slice_boundaries(
                        zarr_array.shape, this_rank, domain_size
                    )
                    raw_data = torch.from_numpy(zarr_array[chunk_start:chunk_stop])
                    placement = [
                        Shard(target_dim),
                    ]

                    # Turn chunk sizes into a dict over mesh dim 0:
                    chunk_sizes = {0: chunk_sizes}

            data[key] = raw_data
            specs[key] = (placement, chunk_sizes)

        # Patch in the optional keys (replicated, no chunk info needed):
        data = self.fill_optional_keys(data)
        for key in data.keys():
            if key not in specs:
                specs[key] = (
                    [
                        Replicate(),
                    ],
                    {},
                )

        return data, specs
|
| 407 |
+
|
| 408 |
+
|
| 409 |
+
if PV_AVAILABLE:

    class VTKFileReader(BackendReader):
        """
        Reader for vtk-style sample directories (.stl / .vtp / .vtu files).

        Each "sample" is a directory; the keys requested determine which of
        the files inside it are actually opened.
        """

        def __init__(
            self,
            keys_to_read: list[str] | None,
            keys_to_read_if_available: dict[str, torch.Tensor] | None,
        ) -> None:
            super().__init__(keys_to_read, keys_to_read_if_available)

            # Keys served by each file type within a sample directory:
            self.stl_file_keys = [
                "stl_coordinates",
                "stl_centers",
                "stl_faces",
                "stl_areas",
            ]
            self.vtp_file_keys = [
                "surface_mesh_centers",
                "surface_normals",
                "surface_mesh_sizes",
                "CpMeanTrim",
                "pMeanTrim",
                "wallShearStressMeanTrim",
            ]
            self.vtu_file_keys = [
                "volume_mesh_centers",
                "volume_fields",
            ]

            # File-name substrings to skip when searching a directory:
            self.exclude_patterns = [
                "single_solid",
            ]

        def get_file_name(self, dir_name: pathlib.Path, extension: str) -> pathlib.Path:
            """
            Get the first non-excluded file with ``extension`` under ``dir_name``.

            Args:
                dir_name: Sample directory to search.
                extension: Suffix to match, including the dot (e.g. ".stl").

            Returns:
                Path to the first matching file.

            Raises:
                FileNotFoundError: If no matching file exists.
            """
            matches = [
                p
                for p in dir_name.iterdir()
                if p.suffix == extension
                and not any(pattern in p.name for pattern in self.exclude_patterns)
            ]
            if len(matches) == 0:
                raise FileNotFoundError(f"No {extension} files found in {dir_name}")
            # iterdir() already yields paths prefixed with dir_name, so the
            # previous `dir_name / matches[0]` duplicated the directory
            # component (e.g. data/run1/data/run1/a.stl). Return it directly.
            return matches[0]

        def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
            """
            Read a set of files and return a dictionary of tensors.

            Only the files needed for the requested keys are opened.

            Raises:
                NotImplementedError: If any requested key requires a .vtu file.
            """

            # This reader attempts to only read what's necessary, and not more.
            # So, the functions that do the reading are each "one file" functions
            # and we open them for processing only when necessary.

            return_data = {}

            # Note that this reader is, already, running in a background thread.
            # It may or may not help to further thread these calls.
            if any(key in self.stl_file_keys for key in self.keys_to_read):
                stl_path = self.get_file_name(filename, ".stl")
                stl_data = self.read_data_from_stl(stl_path)
                return_data.update(stl_data)
            if any(key in self.vtp_file_keys for key in self.keys_to_read):
                vtp_path = self.get_file_name(filename, ".vtp")
                vtp_data = self.read_data_from_vtp(vtp_path)
                return_data.update(vtp_data)
            if any(key in self.vtu_file_keys for key in self.keys_to_read):
                raise NotImplementedError("VTU files are not supported yet.")

            return self.fill_optional_keys(return_data)

        def read_file_sharded(
            self, filename: pathlib.Path, parallel_rank: int, parallel_size: int
        ) -> tuple[dict[str, torch.Tensor], dict[str, ShardTensorSpec]]:
            """
            Sharded reading is not implemented for vtk directories.

            Raises:
                NotImplementedError: Always.
            """
            raise NotImplementedError("Not implemented yet.")

        def read_data_from_stl(
            self,
            stl_path: str,
        ) -> dict:
            """
            Reads surface mesh data from an STL file and prepares a batch dictionary for inference.

            Args:
                stl_path (str): Path to the STL file.

            Returns:
                dict: Batch dictionary with mesh faces, coordinates, and cell
                normals as torch tensors.
            """

            mesh = pv.read(stl_path)

            batch = {}

            # pyvista stores faces as [n, i0, i1, i2, ...]; drop the leading
            # vertex-count column and keep only the vertex indices.
            faces = mesh.faces.reshape(-1, 4)
            faces = faces[:, 1:]

            batch["stl_faces"] = faces.flatten()

            batch["stl_coordinates"] = mesh.points
            batch["surface_normals"] = mesh.cell_normals

            batch = {k: torch.from_numpy(v) for k, v in batch.items()}

            return batch

        def read_data_from_vtp(self, vtp_path: str) -> dict:
            """
            Read vtp file from a file.

            Raises:
                NotImplementedError: Always (not implemented yet).
            """

            raise NotImplementedError("Not implemented yet.")

        def set_volume_sampling_size(self, volume_sampling_size: int):
            """
            This is not supported for vtk files.
            """
            raise NotImplementedError(
                "volume sampling directly from disk is not supported for vtk files."
            )
|
| 540 |
+
|
| 541 |
+
|
| 542 |
+
if TENSORSTORE_AVAILABLE:

    class TensorStoreZarrReader(BackendReader):
        """
        Reader for zarr stores via tensorstore, using async opens/reads for
        concurrency.
        """

        def __init__(
            self,
            keys_to_read: list[str] | None,
            keys_to_read_if_available: dict[str, torch.Tensor] | None,
            cache_bytes_limit: int = 10_000_000,
            data_copy_concurrency: int = 72,
            file_io_concurrency: int = 72,
        ) -> None:
            super().__init__(keys_to_read, keys_to_read_if_available)

            # Base spec; the per-key "path" is filled in by _spec_for.
            self.spec_template = {
                "driver": "auto",
                "kvstore": {
                    "driver": "file",
                    "path": None,
                },
            }

            self.context = ts.Context(
                {
                    "cache_pool": {"total_bytes_limit": cache_bytes_limit},
                    "data_copy_concurrency": {"limit": data_copy_concurrency},
                    "file_io_concurrency": {"limit": file_io_concurrency},
                }
            )

        def _spec_for(self, filename: pathlib.Path, key: str) -> dict:
            """Build a fresh, unshared tensorstore spec for one array.

            The previous ``self.spec_template.copy()`` was a *shallow* copy:
            assigning ``spec["kvstore"]["path"]`` mutated the nested dict
            shared with the template (and with every other spec built from
            it). Rebuilding both levels avoids the aliasing.
            """
            return {
                **self.spec_template,
                "kvstore": {
                    **self.spec_template["kvstore"],
                    "path": str(filename) + "/" + str(key),
                },
            }

        def read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
            """
            Read a file and return a dictionary of tensors.

            Opens and reads every requested key asynchronously, blocking only
            when collecting results.
            """

            # Trigger an async open of each data item:
            read_futures = {}
            for key in self.keys_to_read:
                spec = self._spec_for(filename, key)

                read_futures[key] = ts.open(
                    spec, create=False, open=True, context=self.context
                )

            # Wait for all the opens to conclude:
            read_futures = {
                key: read_futures[key].result() for key in read_futures.keys()
            }

            # Make sure to select the slice outside of the loop, so every
            # "volume" key uses the same contiguous region.
            if self.is_volumetric:
                if self.volume_sampling_size is not None:
                    volume_slice = self.select_random_sections_from_slice(
                        0,
                        read_futures["volume_mesh_centers"].shape[0],
                        self.volume_sampling_size,
                    )
                else:
                    volume_slice = slice(
                        0, read_futures["volume_mesh_centers"].shape[0]
                    )

            # Trigger an async read of each data item:
            tensor_futures = {}
            for key in self.keys_to_read:
                if "volume" not in key:
                    tensor_futures[key] = read_futures[key].read()
                # For the volume data, read only the selected slice:
                else:
                    tensor_futures[key] = read_futures[key][volume_slice].read()

            # Convert them to torch tensors:
            # (make sure to block for the result)
            data = {
                key: torch.as_tensor(tensor_futures[key].result(), dtype=torch.float32)
                for key in self.keys_to_read
            }

            return self.fill_optional_keys(data)

        def read_file_sharded(
            self, filename: pathlib.Path, device_mesh: torch.distributed.DeviceMesh
        ) -> tuple[dict[str, torch.Tensor], dict[str, dict]]:
            """
            Read this rank's shard of a file.

            Arrays are split along dim 0 across the mesh (Shard(0)); scalars
            and arrays smaller than the mesh are read fully on every rank and
            marked Replicate. Returns (tensors, specs) where each spec is a
            (placements, chunk_sizes) pair used later to build ShardTensors.
            """

            # We need the coordinates of this GPU within the (1-D) mesh:
            this_rank = device_mesh.get_local_rank()
            domain_size = dist.get_world_size(group=device_mesh.get_group())

            # This pulls a list of store objects in tensorstore:
            stores = {}
            for key in self.keys_to_read:
                spec = self._spec_for(filename, key)

                stores[key] = ts.open(
                    spec, create=False, open=True, context=self.context
                )

            stores = {key: stores[key].result() for key in stores.keys()}

            data = {}
            specs = {}
            for key in self.keys_to_read:
                # Inspect the array shape without reading it:
                store = stores[key]
                array_shape = store.shape
                if array_shape == ():
                    # Read scalars from every rank and use replicate sharding
                    _slice = np.s_[:]
                    placement = [
                        Replicate(),
                    ]
                    chunk_sizes = None
                else:
                    target_dim = 0
                    if array_shape[target_dim] < domain_size:
                        # If the array is smaller than the number of ranks,
                        # again read and use replicate sharding:
                        _slice = np.s_[:]
                        placement = [
                            Replicate(),
                        ]
                        chunk_sizes = None
                    else:
                        # Read partially from the data and use Shard(target_dim) sharding
                        chunk_start, chunk_stop, chunk_sizes = (
                            self._get_slice_boundaries(
                                store.shape, this_rank, domain_size
                            )
                        )
                        _slice = np.s_[chunk_start:chunk_stop]
                        placement = [
                            Shard(target_dim),
                        ]

                        # Turn chunk sizes into a dict over mesh dim 0:
                        chunk_sizes = {0: chunk_sizes}

                # Trigger the reads as async:
                data[key] = store[_slice].read()
                specs[key] = (placement, chunk_sizes)

            # Finally, await the full data read:
            for key in self.keys_to_read:
                data[key] = torch.as_tensor(data[key].result())

            # Patch in the optional keys (replicated, no chunk info needed):
            data = self.fill_optional_keys(data)
            for key in data.keys():
                if key not in specs:
                    specs[key] = (
                        [
                            Replicate(),
                        ],
                        {},
                    )

            return data, specs

else:

    class TensorStoreZarrReader(BackendReader):
        """
        Null reader for tensorstore zarr files.

        Stands in when tensorstore is not installed; constructing it always
        fails with an actionable message.
        """

        def __init__(
            self,
            keys_to_read: list[str] | None,
            keys_to_read_if_available: dict[str, torch.Tensor] | None,
        ) -> None:
            # Raise an exception on construction if we get here:
            raise NotImplementedError(
                "TensorStoreZarrReader is not available without tensorstore. `pip install tensorstore`."
            )
|
| 729 |
+
|
| 730 |
+
|
| 731 |
+
def is_vtk_directory(file: pathlib.Path) -> bool:
    """
    Return True when ``file`` is a directory whose entries all carry a
    vtk-style suffix (.vtp/.stl/.vtu/.vtk/.csv).

    An empty directory also qualifies (vacuous all()), matching the
    historical behavior.
    """
    if not file.is_dir():
        return False
    vtk_suffixes = {".vtp", ".stl", ".vtu", ".vtk", ".csv"}
    return all(entry.suffix in vtk_suffixes for entry in file.iterdir())
|
| 738 |
+
|
| 739 |
+
|
| 740 |
+
class CAEDataset:
|
| 741 |
+
"""
|
| 742 |
+
Dataset reader for DrivaerML and similar datasets. In general, this
|
| 743 |
+
dataset supports reading dictionary-like data, and returning a
|
| 744 |
+
dictionary of torch.Tensor objects.
|
| 745 |
+
|
| 746 |
+
When constructed, the user must pass a directory of data examples.
|
| 747 |
+
The dataset will inspect the folder, identify all children, and decide:
|
| 748 |
+
- If every file is a directory ending in .zarr, the zarr reader is used.
|
| 749 |
+
- If every file is .npy, the .npy reader is used.
|
| 750 |
+
- If every file is .npz, the .npz reader is used.
|
| 751 |
+
- If every file is a directory without an extension, it's assumed to be .stl/.vtp/.vtu
|
| 752 |
+
|
| 753 |
+
The user can optionally force one path with a parameter.
|
| 754 |
+
|
| 755 |
+
The flow of this dataset is:
|
| 756 |
+
- Load data from file, using a thread.
|
| 757 |
+
- Each individual file reading tool may or may not have its own threading
|
| 758 |
+
or multi processing enabled. That's up to it. This just does async
|
| 759 |
+
loading.
|
| 760 |
+
- Data should come out of the readers in dict{str : torch.Tensor} format
|
| 761 |
+
- The data is transferred from CPU to GPU in a separate stream.
|
| 762 |
+
|
| 763 |
+
Users can call __getitem__(i), which will trigger the pipeline,
|
| 764 |
+
or they can call `preload(i)`, which will start the pipeline for index `i`.
|
| 765 |
+
Subsequent calls to `__getitem__(i)` should be faster since the IO is in
|
| 766 |
+
progress or complete.
|
| 767 |
+
|
| 768 |
+
Using the `__iter__` functionality will automatically enable preloading.
|
| 769 |
+
|
| 770 |
+
"""
|
| 771 |
+
|
| 772 |
+
def __init__(
|
| 773 |
+
self,
|
| 774 |
+
data_dir: str | pathlib.Path,
|
| 775 |
+
keys_to_read: list[str] | None,
|
| 776 |
+
keys_to_read_if_available: dict[str, torch.Tensor] | None,
|
| 777 |
+
output_device: torch.device,
|
| 778 |
+
preload_depth: int = 2,
|
| 779 |
+
pin_memory: bool = False,
|
| 780 |
+
device_mesh: torch.distributed.DeviceMesh | None = None,
|
| 781 |
+
placements: dict[str, torch.distributed.tensor.Placement] | None = None,
|
| 782 |
+
consumer_stream: torch.cuda.Stream | None = None,
|
| 783 |
+
) -> None:
|
| 784 |
+
if isinstance(data_dir, str):
|
| 785 |
+
data_dir = pathlib.Path(data_dir)
|
| 786 |
+
|
| 787 |
+
# Verify the data directory exists:
|
| 788 |
+
if not data_dir.exists():
|
| 789 |
+
raise FileNotFoundError(f"Data directory {data_dir} does not exist")
|
| 790 |
+
|
| 791 |
+
# Verify the data directory is a directory:
|
| 792 |
+
if not data_dir.is_dir():
|
| 793 |
+
raise NotADirectoryError(f"Data directory {data_dir} is not a directory")
|
| 794 |
+
|
| 795 |
+
self._keys_to_read = keys_to_read
|
| 796 |
+
|
| 797 |
+
# Make sure the optional keys are on the right device:
|
| 798 |
+
self._keys_to_read_if_available = {
|
| 799 |
+
k: v.to(output_device) for k, v in keys_to_read_if_available.items()
|
| 800 |
+
}
|
| 801 |
+
|
| 802 |
+
self.file_reader, self._filenames = self._infer_file_type_and_filenames(
|
| 803 |
+
data_dir
|
| 804 |
+
)
|
| 805 |
+
|
| 806 |
+
self.pin_memory = pin_memory
|
| 807 |
+
|
| 808 |
+
# Check the file names; some can be read well in parallel, while others
|
| 809 |
+
# are not parallelizable.
|
| 810 |
+
|
| 811 |
+
self._length = len(self._filenames)
|
| 812 |
+
|
| 813 |
+
self.output_device = output_device
|
| 814 |
+
if output_device.type == "cuda":
|
| 815 |
+
self._data_loader_stream = torch.cuda.Stream()
|
| 816 |
+
else:
|
| 817 |
+
self._data_loader_stream = None
|
| 818 |
+
|
| 819 |
+
self.device_mesh = device_mesh
|
| 820 |
+
self.placements = placements
|
| 821 |
+
# This tracks global tensor info
|
| 822 |
+
# so we can convert to ShardTensor at the right time.
|
| 823 |
+
self.shard_spec = {}
|
| 824 |
+
|
| 825 |
+
if self.device_mesh is not None:
|
| 826 |
+
if self.device_mesh.ndim != 1:
|
| 827 |
+
raise ValueError("Device mesh must be one dimensional")
|
| 828 |
+
|
| 829 |
+
# This is thread storage for data preloading:
|
| 830 |
+
self._preload_queue = {}
|
| 831 |
+
self._transfer_events = {}
|
| 832 |
+
self.preload_depth = preload_depth
|
| 833 |
+
self.preload_executor = ThreadPoolExecutor(max_workers=max(1, preload_depth))
|
| 834 |
+
|
| 835 |
+
if consumer_stream is None and self.output_device.type == "cuda":
|
| 836 |
+
consumer_stream = torch.cuda.current_stream()
|
| 837 |
+
|
| 838 |
+
self.consumer_stream = consumer_stream
|
| 839 |
+
|
| 840 |
+
    def set_indices(self, indices: list[int]):
        """
        Set the indices for the dataset for this epoch.

        Args:
            indices: Ordering (e.g. a shuffled permutation of file indices)
                that idx_to_index uses to map the iteration counter to files.
        """

        # TODO - this needs to block while anything is in the preprocess queue,
        # otherwise in-flight preloads may target stale indices.

        self.indices = indices
| 849 |
+
def idx_to_index(self, idx):
|
| 850 |
+
if hasattr(self, "indices"):
|
| 851 |
+
return self.indices[idx]
|
| 852 |
+
|
| 853 |
+
return idx
|
| 854 |
+
|
| 855 |
+
def _infer_file_type_and_filenames(
|
| 856 |
+
self, data_dir: pathlib.Path
|
| 857 |
+
) -> tuple[str, list[str]]:
|
| 858 |
+
"""
|
| 859 |
+
Infer the file type and filenames from the data directory.
|
| 860 |
+
"""
|
| 861 |
+
|
| 862 |
+
# We validated the directory exists and is a directory already.
|
| 863 |
+
|
| 864 |
+
# List the files:
|
| 865 |
+
files = list(data_dir.iterdir())
|
| 866 |
+
|
| 867 |
+
# Initialize the file reader object
|
| 868 |
+
# Note that for some of these, they could be functions
|
| 869 |
+
# But others benefit from having a state, so we use classes:
|
| 870 |
+
|
| 871 |
+
if all(file.suffix == ".npy" for file in files):
|
| 872 |
+
file_reader = NpyFileReader(
|
| 873 |
+
self._keys_to_read, self._keys_to_read_if_available
|
| 874 |
+
)
|
| 875 |
+
return file_reader, files
|
| 876 |
+
elif all(file.suffix == ".npz" for file in files):
|
| 877 |
+
file_reader = NpzFileReader(
|
| 878 |
+
self._keys_to_read, self._keys_to_read_if_available
|
| 879 |
+
)
|
| 880 |
+
return file_reader, files
|
| 881 |
+
elif all(file.suffix == ".zarr" and file.is_dir() for file in files):
|
| 882 |
+
if TENSORSTORE_AVAILABLE:
|
| 883 |
+
file_reader = TensorStoreZarrReader(
|
| 884 |
+
self._keys_to_read, self._keys_to_read_if_available
|
| 885 |
+
)
|
| 886 |
+
else:
|
| 887 |
+
file_reader = ZarrFileReader(
|
| 888 |
+
self._keys_to_read, self._keys_to_read_if_available
|
| 889 |
+
)
|
| 890 |
+
return file_reader, files
|
| 891 |
+
elif all(is_vtk_directory(file) for file in files):
|
| 892 |
+
file_reader = VTKFileReader(
|
| 893 |
+
self._keys_to_read, self._keys_to_read_if_available
|
| 894 |
+
)
|
| 895 |
+
return file_reader, files
|
| 896 |
+
# Each "file" here is a directory of .vtp, stl, etc.
|
| 897 |
+
else:
|
| 898 |
+
# TODO - support folders of stl, vtp, vtu.
|
| 899 |
+
raise ValueError(f"Unsupported file type: {files[0]}")
|
| 900 |
+
|
| 901 |
+
    def _move_to_gpu(
        self, data: dict[str, torch.Tensor], idx: int
    ) -> dict[str, torch.Tensor]:
        """Move a sample's tensors to the output device on the loader stream.

        Copies run on the dedicated ``_data_loader_stream`` so they overlap
        with compute on the consumer stream; a CUDA event recorded after the
        copies is stashed in ``_transfer_events[idx]`` so __getitem__ can
        synchronize before handing the tensors out.

        Args:
            data: Dictionary of key to torch tensor (host or device).
            idx: Sample index, used to key the transfer-completion event.

        Returns:
            Dictionary of key to torch tensor on the output device; returned
            unchanged when the output device is not CUDA.
        """

        if self.output_device.type != "cuda":
            return data

        result = {}

        with torch.cuda.stream(self._data_loader_stream):
            for key in data.keys():
                # Tensors already on the target device pass through untouched:
                if data[key].device == self.output_device:
                    result[key] = data[key]
                    continue
                if self.pin_memory:
                    # Stage through pinned memory for a true async H2D copy:
                    result[key] = (
                        data[key].pin_memory().to(self.output_device, non_blocking=True)
                    )
                else:
                    result[key] = data[key].to(self.output_device, non_blocking=True)
                # Tell the allocator the consumer stream will use this tensor,
                # so its memory isn't reused before that stream catches up:
                result[key].record_stream(self.consumer_stream)

        # Mark completion of all copies on the loader stream; __getitem__
        # waits on this event from the consumer stream.
        transfer_event = torch.cuda.Event()
        transfer_event.record(self._data_loader_stream)
        self._transfer_events[idx] = transfer_event

        return result
|
| 939 |
+
|
| 940 |
+
    def _convert_to_shard_tensors(
        self,
        tensors: dict[str, torch.Tensor],
        filename: str,
    ) -> dict[str, ShardTensor]:
        """Convert tensors to ShardTensor objects for distributed training.

        Consumes (pops) the sharding spec stashed by _read_file for this
        filename; each entry provides the placements and per-rank chunk
        sizes recorded at read time.

        Args:
            tensors: Dictionary of key to torch tensor.
            filename: The sample's filename, keying the stashed shard spec.

        Returns:
            Dictionary of key to ShardTensor, or the input unchanged when no
            device mesh is configured.
        """

        if self.device_mesh is None:
            return tensors

        # pop() so the spec cache does not grow without bound:
        spec_dict = self.shard_spec.pop(filename)
        result = {}
        for key in tensors.keys():
            placement, chunk_sizes = spec_dict[key]

            result[key] = ShardTensor.from_local(
                local_tensor=tensors[key],
                device_mesh=self.device_mesh,
                placements=placement,
                sharding_shapes=chunk_sizes,
            )

        return result
|
| 970 |
+
|
| 971 |
+
    def preload(self, idx: int) -> None:
        """
        Asynchronously preload the data for the given index on a worker
        thread (including the GPU transfer, when the output device is CUDA).

        Multiple preloads may be in flight at once, one per index; a second
        request for an index already in the queue is a no-op.

        Args:
            idx: Index of the sample to preload.
        """
        if idx in self._preload_queue:
            # Skip items that are already in the queue
            return

        def _preload_worker():
            data = self._read_file(self._filenames[idx])
            if "stl_faces" in data:
                # Face indices are used for indexing; keep them integral:
                data["stl_faces"] = data["stl_faces"].to(torch.int32)
            # Kick off the (async) device transfer from the worker thread:
            return self._move_to_gpu(data, idx)

        self._preload_queue[idx] = self.preload_executor.submit(_preload_worker)
|
| 991 |
+
|
| 992 |
+
def get_preloaded(self, idx: int) -> dict[str, torch.Tensor] | None:
|
| 993 |
+
"""
|
| 994 |
+
Retrieve the preloaded data (blocking if not ready).
|
| 995 |
+
|
| 996 |
+
Returns:
|
| 997 |
+
(idx, data) tuple where data is a dictionary of key to numpy array or torch tensor.
|
| 998 |
+
|
| 999 |
+
Raises:
|
| 1000 |
+
RuntimeError: If no preload is in progress.
|
| 1001 |
+
Exception: If preload failed.
|
| 1002 |
+
"""
|
| 1003 |
+
|
| 1004 |
+
if idx not in self._preload_queue:
|
| 1005 |
+
return None
|
| 1006 |
+
|
| 1007 |
+
result = self._preload_queue[
|
| 1008 |
+
idx
|
| 1009 |
+
].result() # This will block until the result is ready
|
| 1010 |
+
self._preload_queue.pop(idx) # Clear the future after getting the result
|
| 1011 |
+
|
| 1012 |
+
return result
|
| 1013 |
+
|
| 1014 |
+
def __iter__(self):
|
| 1015 |
+
# When starting the iterator method, start loading the data
|
| 1016 |
+
# at idx = 0, idx = 1
|
| 1017 |
+
# Start preprocessing at idx = 0, when the load completes
|
| 1018 |
+
|
| 1019 |
+
self.i = 0
|
| 1020 |
+
|
| 1021 |
+
N = len(self.indices) if hasattr(self, "indices") else len(self)
|
| 1022 |
+
for i in range(self.preload_depth):
|
| 1023 |
+
# Trigger the dataset to start loading index 0:
|
| 1024 |
+
if N > i + 1:
|
| 1025 |
+
self.preload(self.idx_to_index(self.i + i))
|
| 1026 |
+
|
| 1027 |
+
return self
|
| 1028 |
+
|
| 1029 |
+
    def __next__(self):
        """
        Return the next sample, scheduling preloads for upcoming iterations.

        Raises:
            StopIteration: When the epoch's counter reaches the end (the
                counter is also reset to 0 so the dataset can be re-iterated).
        """
        N = len(self.indices) if hasattr(self, "indices") else len(self._filenames)

        # Iteration bounds are based on the counter, not the random-access index
        if self.i >= N:
            self.i = 0
            raise StopIteration

        # This is the file random access index
        target_index = self.idx_to_index(self.i)

        # Before returning, put the next preload_depth target indexes into
        # the queue (preload() ignores indexes already queued):
        for preload_i in range(self.preload_depth):
            next_iteration_index = self.i + preload_i + 1
            if N > next_iteration_index:
                preload_idx = self.idx_to_index(next_iteration_index)
                self.preload(preload_idx)

        # Send up the random-access data:
        data = self.__getitem__(target_index)

        self.i += 1

        return data
|
| 1053 |
+
|
| 1054 |
+
    def __len__(self):
        """Number of files (one sample per file) in the dataset."""
        return len(self._filenames)
|
| 1056 |
+
|
| 1057 |
+
def _read_file(self, filename: pathlib.Path) -> dict[str, torch.Tensor]:
|
| 1058 |
+
"""
|
| 1059 |
+
Read a file and return a dictionary of tensors.
|
| 1060 |
+
"""
|
| 1061 |
+
if self.device_mesh is not None:
|
| 1062 |
+
tensor_dict, spec_dict = self.file_reader.read_file_sharded(
|
| 1063 |
+
filename, self.device_mesh
|
| 1064 |
+
)
|
| 1065 |
+
self.shard_spec[filename] = spec_dict
|
| 1066 |
+
return tensor_dict
|
| 1067 |
+
else:
|
| 1068 |
+
return self.file_reader.read_file(filename)
|
| 1069 |
+
|
| 1070 |
+
    def __getitem__(self, idx: int) -> dict[str, torch.Tensor | ShardTensor]:
        """
        Get a data sample.

        Flow is:
        - Read data, or get preloaded data if this idx is preloaded.
        - Move data to GPU, if needed.
        - Preloading data will move to GPU if it can.
        - If domain parallelism is enabled, convert to ShardTensors.
        - Return

        Args:
            idx: Index of the sample to retrieve

        Returns:
            Dictionary containing tensors/ShardTensors for the requested data
        """

        # NOTE(review): negative indices are not normalized here; idx < 0
        # would bypass this bound check — confirm callers never pass them.
        if idx >= len(self._filenames):
            raise IndexError(
                f"Index {idx} out of range for dataset of size {len(self._filenames)}"
            )

        # Attempt to get preloaded data:
        data = self.get_preloaded(idx)
        if data is None:
            # Cache miss: read synchronously and move to the target device.
            # Read data from zarr file
            data = self._read_file(self._filenames[idx])
            data = self._move_to_gpu(data, idx)

        # This blocks until the preprocessing has transferred to GPU.
        # The event is consumed (popped) so it is only waited on once.
        if idx in self._transfer_events:
            self.consumer_stream.wait_event(self._transfer_events[idx])
            self._transfer_events.pop(idx)

        # Convert to ShardTensors if using domain parallelism
        if self.device_mesh is not None:
            data = self._convert_to_shard_tensors(data, self._filenames[idx])

        return data
|
| 1110 |
+
|
| 1111 |
+
    def set_volume_sampling_size(self, volume_sampling_size: int):
        """
        Set the volume sampling size on the underlying file reader.

        When set, the readers will assume the volumetric data is shuffled on
        disk and read only contiguous chunks of the data up to the sampling
        size, which reduces IO for large volumes.

        Args:
            volume_sampling_size: The total size of the volume sampling.
        """
        self.file_reader.set_volume_sampling_size(volume_sampling_size)
|
| 1121 |
+
|
| 1122 |
+
def close(self):
|
| 1123 |
+
"""
|
| 1124 |
+
Explicitly close the dataset and cleanup resources, including the ThreadPoolExecutor.
|
| 1125 |
+
"""
|
| 1126 |
+
if hasattr(self, "preload_executor") and self.preload_executor is not None:
|
| 1127 |
+
self.preload_executor.shutdown(wait=True)
|
| 1128 |
+
self.preload_executor = None
|
| 1129 |
+
|
| 1130 |
+
    def __del__(self):
        """
        Cleanup resources when the dataset is garbage collected.

        Delegates to close(), which is idempotent (it checks for a live
        executor before shutting down), so an earlier explicit close() is safe.
        """
        self.close()
|
| 1135 |
+
|
| 1136 |
+
|
| 1137 |
+
def compute_mean_std_min_max(
    dataset: "CAEDataset", field_keys: list[str], max_samples: int = 20
) -> tuple[
    dict[str, torch.Tensor],
    dict[str, torch.Tensor],
    dict[str, torch.Tensor],
    dict[str, torch.Tensor],
]:
    """
    Compute per-channel mean, standard deviation, minimum, and maximum for the
    specified fields across (up to ``max_samples``) samples of a dataset.

    Mean and variance use Welford's numerically stable online algorithm with
    the parallel (batch) combination formula.  Min/max are computed in a
    second pass, masking out points more than 9 standard deviations from the
    mean so isolated corrupt values do not dominate the range.

    Args:
        dataset: The dataset to process.  Only ``len(dataset)`` and integer
            indexing returning ``dict[str, torch.Tensor]`` are required.
        field_keys: Keys of the fields to accumulate statistics for.  Each
            field tensor is expected to have shape ``(num_points, channels)``.
        max_samples: Maximum number of randomly chosen samples to process.

    Returns:
        Four dicts keyed by field name: mean, std, min, max.  Each value has
        shape ``(channels,)``.
    """
    N = {}
    mean = {}
    M2 = {}  # Sum of squares of differences from the current mean
    min_val = {}
    max_val = {}

    # Read the first data item to get the shapes:
    example_data = dataset[0]

    # Create placeholders for the accumulators.
    # float64 accumulators keep the running statistics numerically stable.
    for key in field_keys:
        N[key] = torch.zeros(1, dtype=torch.int64, device=example_data[key].device)
        mean[key] = torch.zeros(
            example_data[key].shape[-1],
            device=example_data[key].device,
            dtype=torch.float64,
        )
        M2[key] = torch.zeros(
            example_data[key].shape[-1],
            device=example_data[key].device,
            dtype=torch.float64,
        )
        min_val[key] = torch.full(
            (example_data[key].shape[-1],),
            float("inf"),
            device=example_data[key].device,
        )
        max_val[key] = torch.full(
            (example_data[key].shape[-1],),
            float("-inf"),
            device=example_data[key].device,
        )

    global_start = time.perf_counter()
    start = time.perf_counter()
    data_list = np.arange(len(dataset))
    np.random.shuffle(data_list)
    for i, j in enumerate(data_list):
        # Check the cap *before* loading, so we don't read one extra sample.
        if i >= max_samples:
            break
        data = dataset[j]

        for field_key in field_keys:
            field_data = data[field_key]

            # Compute batch statistics
            batch_mean = field_data.mean(axis=(0))
            batch_M2 = ((field_data - batch_mean) ** 2).sum(axis=(0))
            batch_n = field_data.shape[0]

            # Update running mean and M2 using the parallel form of Welford's
            # algorithm: M2_ab = M2_a + M2_b + delta^2 * n_a * n_b / n_ab.
            # (The previous code used delta^2 * batch_n, inflating variance.)
            delta = batch_mean - mean[field_key]
            N[field_key] += batch_n  # int64 accumulator
            mean[field_key] = mean[field_key] + delta * (batch_n / N[field_key])
            M2[field_key] = (
                M2[field_key]
                + batch_M2
                + delta**2 * (batch_n * (N[field_key] - batch_n)) / N[field_key]
            )

        end = time.perf_counter()
        iteration_time = end - start
        print(
            f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}"
        )
        start = time.perf_counter()

    var = {}
    std = {}
    for field_key in field_keys:
        # Unbiased (sample) variance; N is converted to a Python int here.
        var[field_key] = M2[field_key] / (N[field_key].item() - 1)
        std[field_key] = torch.sqrt(var[field_key])

    start = time.perf_counter()
    for i, j in enumerate(data_list):
        if i >= max_samples:
            break
        data = dataset[j]

        for field_key in field_keys:
            field_data = data[field_key]

            # Update min/max, masking points beyond 9 sigma per channel.
            mean_sample = mean[field_key]
            std_sample = std[field_key]
            mask = torch.ones_like(field_data, dtype=torch.bool)
            for v in range(field_data.shape[-1]):
                outliers = (field_data[:, v] < mean_sample[v] - 9.0 * std_sample[v]) | (
                    field_data[:, v] > mean_sample[v] + 9.0 * std_sample[v]
                )
                mask[:, v] = ~outliers

            batch_min = []
            batch_max = []
            for v in range(field_data.shape[-1]):
                batch_min.append(field_data[mask[:, v], v].min())
                batch_max.append(field_data[mask[:, v], v].max())

            batch_min = torch.stack(batch_min)
            batch_max = torch.stack(batch_max)

            min_val[field_key] = torch.minimum(min_val[field_key], batch_min)
            max_val[field_key] = torch.maximum(max_val[field_key], batch_max)

        end = time.perf_counter()
        iteration_time = end - start
        print(
            f"on iteration {i} of {max_samples}, time: {iteration_time:.2f} seconds for file: {j}"
        )
        start = time.perf_counter()

    global_end = time.perf_counter()
    global_time = global_end - global_start

    print(f"Total time: {global_time:.2f} seconds for {max_samples} samples")

    return mean, std, min_val, max_val
|
physics_mcp/source/physicsnemo/datapipes/cae/domino_datapipe.py
ADDED
|
@@ -0,0 +1,1334 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
"""
|
| 18 |
+
This code provides the datapipe for reading the processed npy files,
|
| 19 |
+
generating multi-res grids, calculating signed distance fields,
|
| 20 |
+
sampling random points in the volume and on surface,
|
| 21 |
+
normalizing fields and returning the output tensors as a dictionary.
|
| 22 |
+
|
| 23 |
+
This datapipe also non-dimensionalizes the fields, so the order of the variables
must be fixed: velocity, pressure, turbulent viscosity for volume variables and
|
| 25 |
+
pressure, wall-shear-stress for surface variables. The different parameters such as
|
| 26 |
+
variable names, domain resolution, sampling size etc. are configurable in config.yaml.
|
| 27 |
+
"""
|
| 28 |
+
|
| 29 |
+
from dataclasses import dataclass
|
| 30 |
+
from pathlib import Path
|
| 31 |
+
from typing import Iterable, Literal, Optional, Protocol, Sequence, Union
|
| 32 |
+
|
| 33 |
+
import numpy as np
|
| 34 |
+
import torch
|
| 35 |
+
import torch.cuda.nvtx as nvtx
|
| 36 |
+
import torch.distributed as dist
|
| 37 |
+
from omegaconf import DictConfig
|
| 38 |
+
from torch.distributed.tensor.placement_types import Replicate, Shard
|
| 39 |
+
from torch.utils.data import Dataset
|
| 40 |
+
|
| 41 |
+
from physicsnemo.datapipes.cae.cae_dataset import (
|
| 42 |
+
CAEDataset,
|
| 43 |
+
compute_mean_std_min_max,
|
| 44 |
+
)
|
| 45 |
+
from physicsnemo.distributed import DistributedManager
|
| 46 |
+
from physicsnemo.distributed.shard_tensor import ShardTensor, scatter_tensor
|
| 47 |
+
from physicsnemo.utils.domino.utils import (
|
| 48 |
+
calculate_center_of_mass,
|
| 49 |
+
create_grid,
|
| 50 |
+
get_filenames,
|
| 51 |
+
normalize,
|
| 52 |
+
pad,
|
| 53 |
+
shuffle_array,
|
| 54 |
+
standardize,
|
| 55 |
+
unnormalize,
|
| 56 |
+
unstandardize,
|
| 57 |
+
)
|
| 58 |
+
from physicsnemo.utils.neighbors import knn
|
| 59 |
+
from physicsnemo.utils.profiling import profile
|
| 60 |
+
from physicsnemo.utils.sdf import signed_distance_field
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
class BoundingBox(Protocol):
    """
    Structural type for bounding box dimensions.

    Any object that exposes arraylike ``min`` and ``max`` attributes
    satisfies this protocol; no inheritance is required.
    """

    # Minimum corner of the box, one value per spatial dimension.
    min: Sequence
    # Maximum corner of the box, one value per spatial dimension.
    max: Sequence
|
| 70 |
+
|
| 71 |
+
|
| 72 |
+
@dataclass
|
| 73 |
+
class DoMINODataConfig:
|
| 74 |
+
"""Configuration for DoMINO dataset processing pipeline.
|
| 75 |
+
|
| 76 |
+
Attributes:
|
| 77 |
+
data_path: Path to the dataset to load.
|
| 78 |
+
phase: Which phase of data to load ("train", "val", or "test").
|
| 79 |
+
surface_variables: (Surface specific) Names of surface variables.
|
| 80 |
+
surface_points_sample: (Surface specific) Number of surface points to sample per batch.
|
| 81 |
+
num_surface_neighbors: (Surface specific) Number of surface neighbors to consider for nearest neighbors approach.
|
| 82 |
+
surface_sampling_algorithm: (Surface specific) Algorithm to use for surface sampling ("area_weighted" or "random").
|
| 83 |
+
surface_factors: (Surface specific) Non-dimensionalization factors for surface variables.
|
| 84 |
+
If set, and scaling_type is:
|
| 85 |
+
- min_max_scaling -> rescale surface_fields to the min/max set here
|
| 86 |
+
- mean_std_scaling -> rescale surface_fields to the mean and std set here.
|
| 87 |
+
bounding_box_dims_surf: (Surface specific) Dimensions of bounding box. Must be an object with min/max
|
| 88 |
+
attributes that are arraylike.
|
| 89 |
+
volume_variables: (Volume specific) Names of volume variables.
|
| 90 |
+
volume_points_sample: (Volume specific) Number of volume points to sample per batch.
|
| 91 |
+
volume_sample_from_disk: (Volume specific) If the volume data is in a shuffled state on disk,
|
| 92 |
+
read contiguous chunks of the data rather than the entire volume data. This greatly
|
| 93 |
+
accelerates IO in bandwidth limited systems or when the volumetric data is very large.
|
| 94 |
+
volume_factors: (Volume specific) Non-dimensionalization factors for volume variables scaling.
|
| 95 |
+
If set, and scaling_type is:
|
| 96 |
+
- min_max_scaling -> rescale volume_fields to the min/max set here
|
| 97 |
+
- mean_std_scaling -> rescale volume_fields to the mean and std set here.
|
| 98 |
+
bounding_box_dims: (Volume specific) Dimensions of bounding box. Must be an object with min/max
|
| 99 |
+
attributes that are arraylike.
|
| 100 |
+
grid_resolution: Resolution of the latent grid.
|
| 101 |
+
normalize_coordinates: Whether to normalize coordinates based on min/max values.
|
| 102 |
+
For surfaces: uses s_min/s_max, defined from:
|
| 103 |
+
- Surface bounding box, if defined.
|
| 104 |
+
- Min/max of the stl_vertices
|
| 105 |
+
For volumes: uses c_min/c_max, defined from:
|
| 106 |
+
- Volume bounding_box if defined,
|
| 107 |
+
- 1.5x s_min/max otherwise, except c_min[2] = s_min[2] in this case
|
| 108 |
+
sample_in_bbox: Whether to sample points in a specified bounding box.
|
| 109 |
+
Uses the same min/max points as coordinate normalization.
|
| 110 |
+
Only performed if compute_scaling_factors is false.
|
| 111 |
+
sampling: Whether to downsample the full resolution mesh to fit in GPU memory.
|
| 112 |
+
Surface and volume sampling points are configured separately as:
|
| 113 |
+
- surface.points_sample
|
| 114 |
+
- volume.points_sample
|
| 115 |
+
geom_points_sample: Number of STL points sampled per batch.
|
| 116 |
+
Independent of volume.points_sample and surface.points_sample.
|
| 117 |
+
scaling_type: Scaling type for volume variables.
|
| 118 |
+
If used, will rescale the volume_fields and surface fields outputs.
|
| 119 |
+
Requires volume.factor and surface.factor to be set.
|
| 120 |
+
compute_scaling_factors: Whether to compute scaling factors.
|
| 121 |
+
Not available if caching.
|
| 122 |
+
Many preprocessing pieces are disabled if computing scaling factors.
|
| 123 |
+
caching: Whether this is for caching or serving.
|
| 124 |
+
deterministic: Whether to use a deterministic seed for sampling and random numbers.
|
| 125 |
+
gpu_preprocessing: Whether to do preprocessing on the GPU (False for CPU).
|
| 126 |
+
gpu_output: Whether to return output on the GPU as cupy arrays.
|
| 127 |
+
If False, returns numpy arrays.
|
| 128 |
+
You might choose gpu_preprocessing=True and gpu_output=False if caching.
|
| 129 |
+
shard_grid: Whether to shard the grid across GPUs for domain parallelism.
|
| 130 |
+
Applies to the surf_grid and similiar tensors.
|
| 131 |
+
shard_points: Whether to shard the points across GPUs for domain parallelism.
|
| 132 |
+
Applies to the volume_fields/surface_fields and similiar tensors.
|
| 133 |
+
"""
|
| 134 |
+
|
| 135 |
+
data_path: Path | None
|
| 136 |
+
phase: Literal["train", "val", "test"]
|
| 137 |
+
|
| 138 |
+
# Surface-specific variables:
|
| 139 |
+
surface_variables: Optional[Sequence] = ("pMean", "wallShearStress")
|
| 140 |
+
surface_points_sample: int = 1024
|
| 141 |
+
num_surface_neighbors: int = 11
|
| 142 |
+
surface_sampling_algorithm: str = Literal["area_weighted", "random"]
|
| 143 |
+
surface_factors: Optional[Sequence] = None
|
| 144 |
+
bounding_box_dims_surf: Optional[Union[BoundingBox, Sequence]] = None
|
| 145 |
+
|
| 146 |
+
# Volume specific variables:
|
| 147 |
+
volume_variables: Optional[Sequence] = ("UMean", "pMean")
|
| 148 |
+
volume_points_sample: int = 1024
|
| 149 |
+
volume_sample_from_disk: bool = False
|
| 150 |
+
volume_factors: Optional[Sequence] = None
|
| 151 |
+
bounding_box_dims: Optional[Union[BoundingBox, Sequence]] = None
|
| 152 |
+
|
| 153 |
+
grid_resolution: Sequence = (256, 96, 64)
|
| 154 |
+
normalize_coordinates: bool = False
|
| 155 |
+
sample_in_bbox: bool = False
|
| 156 |
+
sampling: bool = False
|
| 157 |
+
geom_points_sample: int = 300000
|
| 158 |
+
scaling_type: Optional[Literal["min_max_scaling", "mean_std_scaling"]] = None
|
| 159 |
+
compute_scaling_factors: bool = False
|
| 160 |
+
caching: bool = False
|
| 161 |
+
deterministic: bool = False
|
| 162 |
+
gpu_preprocessing: bool = True
|
| 163 |
+
gpu_output: bool = True
|
| 164 |
+
|
| 165 |
+
shard_grid: bool = False
|
| 166 |
+
shard_points: bool = False
|
| 167 |
+
|
| 168 |
+
def __post_init__(self):
|
| 169 |
+
if self.data_path is not None:
|
| 170 |
+
# Ensure data_path is a Path object:
|
| 171 |
+
if isinstance(self.data_path, str):
|
| 172 |
+
self.data_path = Path(self.data_path)
|
| 173 |
+
self.data_path = self.data_path.expanduser()
|
| 174 |
+
|
| 175 |
+
if not self.data_path.exists():
|
| 176 |
+
raise ValueError(f"Path {self.data_path} does not exist")
|
| 177 |
+
|
| 178 |
+
if not self.data_path.is_dir():
|
| 179 |
+
raise ValueError(f"Path {self.data_path} is not a directory")
|
| 180 |
+
|
| 181 |
+
# Object if caching settings are impossible:
|
| 182 |
+
if self.caching:
|
| 183 |
+
if self.sampling:
|
| 184 |
+
raise ValueError("Sampling should be False for caching")
|
| 185 |
+
if self.compute_scaling_factors:
|
| 186 |
+
raise ValueError("Compute scaling factors should be False for caching")
|
| 187 |
+
|
| 188 |
+
if self.phase not in [
|
| 189 |
+
"train",
|
| 190 |
+
"val",
|
| 191 |
+
"test",
|
| 192 |
+
]:
|
| 193 |
+
raise ValueError(
|
| 194 |
+
f"phase should be one of ['train', 'val', 'test'], got {self.phase}"
|
| 195 |
+
)
|
| 196 |
+
if self.scaling_type is not None:
|
| 197 |
+
if self.scaling_type not in [
|
| 198 |
+
"min_max_scaling",
|
| 199 |
+
"mean_std_scaling",
|
| 200 |
+
]:
|
| 201 |
+
raise ValueError(
|
| 202 |
+
f"scaling_type should be one of ['min_max_scaling', 'mean_std_scaling'], got {self.scaling_type}"
|
| 203 |
+
)
|
| 204 |
+
|
| 205 |
+
|
| 206 |
+
##### TODO
|
| 207 |
+
# - The SDF normalization here is based on using a normalized mesh and
|
| 208 |
+
# a normalized coordinate. The alternate method is to normalize to the min/max of the grid.
|
| 209 |
+
|
| 210 |
+
|
| 211 |
+
class DoMINODataPipe(Dataset):
|
| 212 |
+
"""
|
| 213 |
+
Datapipe for DoMINO
|
| 214 |
+
|
| 215 |
+
Leverages a dataset for the actual reading of the data, and this
|
| 216 |
+
object is responsible for preprocessing the data.
|
| 217 |
+
|
| 218 |
+
"""
|
| 219 |
+
|
| 220 |
+
    def __init__(
        self,
        input_path,
        model_type: Literal["surface", "volume", "combined"],
        pin_memory: bool = False,
        **data_config_overrides,
    ):
        # NOTE(review): pin_memory is accepted but not used in this method —
        # confirm it is consumed elsewhere (e.g. by a DataLoader wrapper).

        # Perform config packaging and validation
        self.config = DoMINODataConfig(data_path=input_path, **data_config_overrides)

        # Set up the distributed manager:
        if not DistributedManager.is_initialized():
            DistributedManager.initialize()

        # NOTE(review): this local name shadows the module-level
        # `torch.distributed as dist` import within this method body.
        dist = DistributedManager()

        # Set devices for the preprocessing and IO target
        self.preproc_device = (
            dist.device if self.config.gpu_preprocessing else torch.device("cpu")
        )
        # The cae_dataset will automatically target this device
        # In an async transfer.
        self.output_device = (
            dist.device if self.config.gpu_output else torch.device("cpu")
        )

        # Model type determines whether we process surface, volume, or both.
        self.model_type = model_type

        # Update the arrays for bounding boxes:
        # If the volume bounding box has min/max attributes, replace it
        # in-place with [max_tensor, min_tensor] on the preprocessing device.
        # Note the ordering convention: index 0 is max, index 1 is min.
        if hasattr(self.config.bounding_box_dims, "max") and hasattr(
            self.config.bounding_box_dims, "min"
        ):
            self.config.bounding_box_dims = [
                torch.tensor(
                    self.config.bounding_box_dims.max,
                    device=self.preproc_device,
                    dtype=torch.float32,
                ),
                torch.tensor(
                    self.config.bounding_box_dims.min,
                    device=self.preproc_device,
                    dtype=torch.float32,
                ),
            ]
            # Precompute the latent grid spanning the volume bounding box.
            # NOTE(review): if the bounding box lacks min/max attributes,
            # default_volume_grid is never set — confirm downstream code
            # (compute_volume_scaling_and_grids) cannot be reached in that case.
            self.default_volume_grid = create_grid(
                self.config.bounding_box_dims[0],
                self.config.bounding_box_dims[1],
                self.config.grid_resolution,
            )

        # And, do the surface bounding box if supplied:
        # Same [max, min] ordering convention as the volume box above.
        if hasattr(self.config.bounding_box_dims_surf, "max") and hasattr(
            self.config.bounding_box_dims_surf, "min"
        ):
            self.config.bounding_box_dims_surf = [
                torch.tensor(
                    self.config.bounding_box_dims_surf.max,
                    device=self.preproc_device,
                    dtype=torch.float32,
                ),
                torch.tensor(
                    self.config.bounding_box_dims_surf.min,
                    device=self.preproc_device,
                    dtype=torch.float32,
                ),
            ]

            self.default_surface_grid = create_grid(
                self.config.bounding_box_dims_surf[0],
                self.config.bounding_box_dims_surf[1],
                self.config.grid_resolution,
            )

        # Ensure the volume and surface scaling factors are torch tensors
        # and on the right device:
        # NOTE(review): non-tensor factors are assumed to be numpy arrays
        # (torch.from_numpy) — a plain list here would raise; confirm callers.
        if self.config.volume_factors is not None:
            if not isinstance(self.config.volume_factors, torch.Tensor):
                self.config.volume_factors = torch.from_numpy(
                    self.config.volume_factors
                )
            self.config.volume_factors = self.config.volume_factors.to(
                self.preproc_device, dtype=torch.float32
            )
        if self.config.surface_factors is not None:
            if not isinstance(self.config.surface_factors, torch.Tensor):
                self.config.surface_factors = torch.from_numpy(
                    self.config.surface_factors
                )
            self.config.surface_factors = self.config.surface_factors.to(
                self.preproc_device, dtype=torch.float32
            )

        # The underlying CAEDataset is attached later, not at construction.
        self.dataset = None
|
| 314 |
+
|
| 315 |
+
def compute_stl_scaling_and_surface_grids(
|
| 316 |
+
self,
|
| 317 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 318 |
+
"""
|
| 319 |
+
Compute the min and max for the defining mesh.
|
| 320 |
+
|
| 321 |
+
If the user supplies a bounding box, we use that. Otherwise,
|
| 322 |
+
it raises an error.
|
| 323 |
+
|
| 324 |
+
The returned min/max and grid are used for surface data.
|
| 325 |
+
"""
|
| 326 |
+
|
| 327 |
+
# Check the bounding box is not unit length
|
| 328 |
+
|
| 329 |
+
if self.config.bounding_box_dims_surf is not None:
|
| 330 |
+
s_max = self.config.bounding_box_dims_surf[0]
|
| 331 |
+
s_min = self.config.bounding_box_dims_surf[1]
|
| 332 |
+
surf_grid = self.default_surface_grid
|
| 333 |
+
else:
|
| 334 |
+
raise ValueError("Bounding box dimensions are not set in config")
|
| 335 |
+
|
| 336 |
+
return s_min, s_max, surf_grid
|
| 337 |
+
|
| 338 |
+
def compute_volume_scaling_and_grids(
|
| 339 |
+
self,
|
| 340 |
+
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor]:
|
| 341 |
+
"""
|
| 342 |
+
Compute the min and max and grid for volume data.
|
| 343 |
+
|
| 344 |
+
If the user supplies a bounding box, we use that. Otherwise,
|
| 345 |
+
it raises an error.
|
| 346 |
+
|
| 347 |
+
"""
|
| 348 |
+
|
| 349 |
+
# Determine the volume min / max locations
|
| 350 |
+
if self.config.bounding_box_dims is not None:
|
| 351 |
+
c_max = self.config.bounding_box_dims[0]
|
| 352 |
+
c_min = self.config.bounding_box_dims[1]
|
| 353 |
+
volume_grid = self.default_volume_grid
|
| 354 |
+
else:
|
| 355 |
+
raise ValueError("Bounding box dimensions are not set in config")
|
| 356 |
+
|
| 357 |
+
return c_min, c_max, volume_grid
|
| 358 |
+
|
| 359 |
+
@profile
|
| 360 |
+
def downsample_geometry(
|
| 361 |
+
self,
|
| 362 |
+
stl_vertices,
|
| 363 |
+
) -> torch.Tensor:
|
| 364 |
+
"""
|
| 365 |
+
Downsample the geometry to the desired number of points.
|
| 366 |
+
|
| 367 |
+
Args:
|
| 368 |
+
stl_vertices: The vertices of the surface.
|
| 369 |
+
"""
|
| 370 |
+
|
| 371 |
+
if self.config.sampling:
|
| 372 |
+
geometry_points = self.config.geom_points_sample
|
| 373 |
+
|
| 374 |
+
geometry_coordinates_sampled, idx_geometry = shuffle_array(
|
| 375 |
+
stl_vertices, geometry_points
|
| 376 |
+
)
|
| 377 |
+
if geometry_coordinates_sampled.shape[0] < geometry_points:
|
| 378 |
+
raise ValueError(
|
| 379 |
+
"Surface mesh has fewer points than requested sample size"
|
| 380 |
+
)
|
| 381 |
+
geom_centers = geometry_coordinates_sampled
|
| 382 |
+
else:
|
| 383 |
+
geom_centers = stl_vertices
|
| 384 |
+
|
| 385 |
+
return geom_centers
|
| 386 |
+
|
| 387 |
+
    def process_surface(
        self,
        s_min: torch.Tensor,
        s_max: torch.Tensor,
        c_min: torch.Tensor,
        c_max: torch.Tensor,
        *,  # Forcing the rest by keyword only since it's a long list ...
        center_of_mass: torch.Tensor,
        surf_grid: torch.Tensor,
        surface_coordinates: torch.Tensor,
        surface_normals: torch.Tensor,
        surface_sizes: torch.Tensor,
        stl_vertices: torch.Tensor,
        stl_indices: torch.Tensor,
        surface_fields: torch.Tensor | None,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess the surface mesh data for one example.

        Pipeline: drop degenerate faces (area <= 0), optionally reject points
        outside the *volume* bounding box, optionally subsample the surface,
        gather kNN neighbor features from the full-resolution surface,
        normalize coordinates, and scale target fields.

        Args:
            s_min: Surface bounding-box min (used for coordinate normalization).
            s_max: Surface bounding-box max (used for coordinate normalization).
            c_min: Volume bounding-box min (used for in-bbox rejection).
            c_max: Volume bounding-box max (used for in-bbox rejection).
            center_of_mass: Center of mass of the geometry.
            surf_grid: Surface grid.  NOTE(review): not referenced in this
                body — confirm whether the parameter can be dropped.
            surface_coordinates: Surface cell centers.
            surface_normals: Surface cell normals.
            surface_sizes: Surface cell areas.
            stl_vertices: STL mesh vertices.  NOTE(review): unused here.
            stl_indices: STL mesh face indices.  NOTE(review): unused here.
            surface_fields: Optional ground-truth surface fields; all
                filtering/sampling applied to coordinates is mirrored here.

        Returns:
            Dict of preprocessed surface tensors; contains "surface_fields"
            only when ground truth was provided.

        Raises:
            ValueError: if sampling requests more points than are available.
        """
        # NOTE(review): nx/ny/nz appear unused below — confirm intent.
        nx, ny, nz = self.config.grid_resolution

        return_dict = {}

        ########################################################################
        # Remove any sizes <= 0:
        ########################################################################
        idx = surface_sizes > 0
        surface_sizes = surface_sizes[idx]
        surface_normals = surface_normals[idx]
        surface_coordinates = surface_coordinates[idx]
        if surface_fields is not None:
            surface_fields = surface_fields[idx]

        ########################################################################
        # Reject surface points outside of the Bounding Box
        # NOTE - this is using the VOLUME bounding box!
        ########################################################################
        if self.config.sample_in_bbox:
            ids_min = surface_coordinates[:] > c_min
            ids_max = surface_coordinates[:] < c_max

            # Keep only points strictly inside the box on every axis.
            ids_in_bbox = ids_min & ids_max
            ids_in_bbox = ids_in_bbox.all(dim=-1)

            surface_coordinates = surface_coordinates[ids_in_bbox]
            surface_normals = surface_normals[ids_in_bbox]
            surface_sizes = surface_sizes[ids_in_bbox]
            if surface_fields is not None:
                surface_fields = surface_fields[ids_in_bbox]

        ########################################################################
        # Perform Down sampling of the surface fields.
        # Note that we snapshot the full surface coordinates for
        # use in the kNN in the next step.
        ########################################################################

        full_surface_coordinates = surface_coordinates
        full_surface_normals = surface_normals
        full_surface_sizes = surface_sizes

        if self.config.sampling:
            # Perform the down sampling:
            # Area-weighted sampling favors larger cells.
            if self.config.surface_sampling_algorithm == "area_weighted":
                weights = surface_sizes
            else:
                weights = None

            surface_coordinates_sampled, idx_surface = shuffle_array(
                surface_coordinates,
                self.config.surface_points_sample,
                weights=weights,
            )

            if surface_coordinates_sampled.shape[0] < self.config.surface_points_sample:
                raise ValueError(
                    "Surface mesh has fewer points than requested sample size"
                )

            # Select out the sampled points for non-neighbor arrays:
            if surface_fields is not None:
                surface_fields = surface_fields[idx_surface]

            # Subsample the normals and sizes:
            surface_normals = surface_normals[idx_surface]
            surface_sizes = surface_sizes[idx_surface]
            # Update the coordinates to the sampled points:
            surface_coordinates = surface_coordinates_sampled

        ########################################################################
        # Perform a kNN on the surface to find the neighbor information
        ########################################################################
        if self.config.num_surface_neighbors > 1:
            # Perform the kNN:
            neighbor_indices, neighbor_distances = knn(
                points=full_surface_coordinates,
                queries=surface_coordinates,
                k=self.config.num_surface_neighbors,
            )
            # Pull out the neighbor elements.
            # Note that `neighbor_indices` is the index into the original,
            # full sized tensors (full_surface_coordinates, etc).
            # The [:, 1:] slice drops each query's nearest neighbor —
            # presumably the query point itself when the query is drawn from
            # the same point set; NOTE(review): confirm.
            surface_neighbors = full_surface_coordinates[neighbor_indices][:, 1:]
            surface_neighbors_normals = full_surface_normals[neighbor_indices][:, 1:]
            surface_neighbors_sizes = full_surface_sizes[neighbor_indices][:, 1:]
        else:
            surface_neighbors = surface_coordinates
            surface_neighbors_normals = surface_normals
            surface_neighbors_sizes = surface_sizes

        # Better to normalize everything after the kNN and sampling
        if self.config.normalize_coordinates:
            surface_coordinates = normalize(surface_coordinates, s_max, s_min)
            surface_neighbors = normalize(surface_neighbors, s_max, s_min)
            center_of_mass = normalize(center_of_mass, s_max, s_min)

        # Displacement of each surface point from the center of mass.
        pos_normals_com_surface = surface_coordinates - center_of_mass

        ########################################################################
        # Apply scaling to the targets, if desired:
        ########################################################################
        if self.config.scaling_type is not None and surface_fields is not None:
            surface_fields = self.scale_model_targets(
                surface_fields, self.config.surface_factors
            )

        return_dict.update(
            {
                "pos_surface_center_of_mass": pos_normals_com_surface,
                "surface_mesh_centers": surface_coordinates,
                "surface_mesh_neighbors": surface_neighbors,
                "surface_normals": surface_normals,
                "surface_neighbors_normals": surface_neighbors_normals,
                "surface_areas": surface_sizes,
                "surface_neighbors_areas": surface_neighbors_sizes,
            }
        )
        if surface_fields is not None:
            return_dict["surface_fields"] = surface_fields

        return return_dict
|
| 525 |
+
|
| 526 |
+
    def process_volume(
        self,
        c_min: torch.Tensor,
        c_max: torch.Tensor,
        volume_coordinates: torch.Tensor,
        volume_grid: torch.Tensor,
        center_of_mass: torch.Tensor,
        stl_vertices: torch.Tensor,
        stl_indices: torch.Tensor,
        volume_fields: torch.Tensor | None,
    ) -> dict[str, torch.Tensor]:
        """
        Preprocess the volume data.

        First, if configured, we reject points not in the volume bounding box.

        Next, if sampling is enabled, we sample the volume points and apply that
        sampling to the ground truth too, if it's present.

        Finally, coordinates are optionally normalized, targets are optionally
        scaled, and signed-distance-field features are computed against the
        (possibly normalized) STL mesh.

        Args:
            c_min: Volume bounding-box min.
            c_max: Volume bounding-box max.
            volume_coordinates: Volume cell centers.
            volume_grid: Regular grid over the volume bounding box.
            center_of_mass: Center of mass of the geometry.
            stl_vertices: STL mesh vertices (used for the SDF computation).
            stl_indices: STL mesh face indices (used for the SDF computation).
            volume_fields: Optional ground-truth volume fields.

        Returns:
            Dict of preprocessed volume tensors; contains "volume_fields"
            only when ground truth was provided.

        Raises:
            ValueError: if sampling requests more points than are available.
        """
        ########################################################################
        # Reject points outside the volumetric BBox
        ########################################################################
        if self.config.sample_in_bbox:
            # Remove points in the volume that are outside
            # of the bbox area.
            min_check = volume_coordinates[:] > c_min
            max_check = volume_coordinates[:] < c_max

            ids_in_bbox = min_check & max_check
            ids_in_bbox = ids_in_bbox.all(dim=1)

            volume_coordinates = volume_coordinates[ids_in_bbox]
            if volume_fields is not None:
                volume_fields = volume_fields[ids_in_bbox]

        ########################################################################
        # Apply sampling to the volume coordinates and fields
        ########################################################################

        # If the volume data has been sampled from disk, directly, then
        # still apply sampling. We over-pull from disk deliberately.
        if self.config.sampling:
            # Generate a series of idx to sample the volume
            # without replacement
            volume_coordinates_sampled, idx_volume = shuffle_array(
                volume_coordinates, self.config.volume_points_sample
            )
            # NOTE(review): this re-gather looks redundant if shuffle_array
            # already returns volume_coordinates[idx_volume] — confirm.
            volume_coordinates_sampled = volume_coordinates[idx_volume]
            # Raise if too few points were available to fill the sample.
            if volume_coordinates_sampled.shape[0] < self.config.volume_points_sample:
                raise ValueError(
                    "Volume mesh has fewer points than requested sample size"
                )

            # Apply the same sampling to the targets, too:
            if volume_fields is not None:
                volume_fields = volume_fields[idx_volume]

            volume_coordinates = volume_coordinates_sampled

        ########################################################################
        # Apply normalization to the coordinates, if desired:
        ########################################################################
        if self.config.normalize_coordinates:
            volume_coordinates = normalize(volume_coordinates, c_max, c_min)
            grid = normalize(volume_grid, c_max, c_min)
            normed_vertices = normalize(stl_vertices, c_max, c_min)
            center_of_mass = normalize(center_of_mass, c_max, c_min)
        else:
            grid = volume_grid
            normed_vertices = stl_vertices
            center_of_mass = center_of_mass

        ########################################################################
        # Apply scaling to the targets, if desired:
        ########################################################################
        if self.config.scaling_type is not None and volume_fields is not None:
            volume_fields = self.scale_model_targets(
                volume_fields, self.config.volume_factors
            )

        ########################################################################
        # Compute Signed Distance Function for volumetric quantities
        # Note - the SDF happens here, after volume data processing finishes,
        # because we need to use the (maybe) normalized volume coordinates and grid
        ########################################################################

        # SDF calculation on the volume grid using WARP
        sdf_grid, _ = signed_distance_field(
            normed_vertices,
            stl_indices,
            grid,
            use_sign_winding_number=True,
        )

        # Get the SDF of all the selected volume coordinates,
        # And keep the closest point to each one.
        sdf_nodes, sdf_node_closest_point = signed_distance_field(
            normed_vertices,
            stl_indices,
            volume_coordinates,
            use_sign_winding_number=True,
        )
        sdf_nodes = sdf_nodes.reshape((-1, 1))

        # Use the closest point from the mesh to compute the volume encodings:
        pos_normals_closest_vol, pos_normals_com_vol = self.calculate_volume_encoding(
            volume_coordinates, sdf_node_closest_point, center_of_mass
        )

        return_dict = {
            "volume_mesh_centers": volume_coordinates,
            "sdf_nodes": sdf_nodes,
            "grid": grid,
            "sdf_grid": sdf_grid,
            "pos_volume_closest": pos_normals_closest_vol,
            "pos_volume_center_of_mass": pos_normals_com_vol,
        }

        if volume_fields is not None:
            return_dict["volume_fields"] = volume_fields

        return return_dict
|
| 651 |
+
|
| 652 |
+
def calculate_volume_encoding(
|
| 653 |
+
self,
|
| 654 |
+
volume_coordinates: torch.Tensor,
|
| 655 |
+
sdf_node_closest_point: torch.Tensor,
|
| 656 |
+
center_of_mass: torch.Tensor,
|
| 657 |
+
):
|
| 658 |
+
pos_normals_closest_vol = volume_coordinates - sdf_node_closest_point
|
| 659 |
+
pos_normals_com_vol = volume_coordinates - center_of_mass
|
| 660 |
+
|
| 661 |
+
return pos_normals_closest_vol, pos_normals_com_vol
|
| 662 |
+
|
| 663 |
+
    @torch.no_grad()
    def process_data(self, data_dict):
        """
        Run the full DoMINO preprocessing pipeline on one example.

        Steps: validate required keys, (for sharded inputs) gather full
        tensors locally, compute the surface-grid SDF, the center of mass,
        and geometry downsampling, then dispatch surface and/or volume
        preprocessing based on ``self.model_type``, and finally re-scatter
        results when domain parallelism is configured.

        Args:
            data_dict: Mapping of tensor names to torch.Tensors (or
                ShardTensors when sharding is enabled).

        Returns:
            Dict of preprocessed tensors (ShardTensors when sharding is
            enabled).

        Raises:
            ValueError: if any required key is missing from ``data_dict``.
        """
        # Validate that all required keys are present in data_dict
        required_keys = [
            "global_params_values",
            "global_params_reference",
            "stl_coordinates",
            "stl_faces",
            "stl_centers",
            "stl_areas",
        ]
        missing_keys = [key for key in required_keys if key not in data_dict]
        if missing_keys:
            raise ValueError(
                f"Missing required keys in data_dict: {missing_keys}. "
                f"Required keys are: {required_keys}"
            )

        # Start building the preprocessed return dict:
        return_dict = {
            "global_params_values": data_dict["global_params_values"],
            "global_params_reference": data_dict["global_params_reference"],
        }

        # DoMINO's sharded datapipe can be tricky - output shapes are not always
        # so simple to calculate, since much of the datapipe is dynamic.
        # The dataset will read in sharded data, to minimize IO.
        # We collect it all locally, here, and then scatter
        # appropriately for the outputs.

        if self.config.shard_grid or self.config.shard_points:
            # Get the mesh:
            mesh = data_dict["stl_coordinates"]._spec.mesh
            # Gather every sharded tensor to a full local tensor before
            # preprocessing; results are re-scattered at the end.
            local_data_dict = {}
            for key, value in data_dict.items():
                local_data_dict[key] = value.full_tensor()

            data_dict = local_data_dict

        ########################################################################
        # Process the core STL information
        ########################################################################

        # This function gets information about the surface scale,
        # and decides what the surface grid will be:

        s_min, s_max, surf_grid = self.compute_stl_scaling_and_surface_grids()

        # We always need to calculate the SDF on the surface grid:
        # This is for the SDF Later:
        if self.config.normalize_coordinates:
            normed_vertices = normalize(data_dict["stl_coordinates"], s_max, s_min)
            surf_grid = normalize(surf_grid, s_max, s_min)
        else:
            normed_vertices = data_dict["stl_coordinates"]

        # For SDF calculations, make sure the mesh_indices_flattened is an integer array:
        mesh_indices_flattened = data_dict["stl_faces"].to(torch.int32)

        # Compute signed distance function for the surface grid:
        sdf_surf_grid, _ = signed_distance_field(
            mesh_vertices=normed_vertices,
            mesh_indices=mesh_indices_flattened,
            input_points=surf_grid,
            use_sign_winding_number=True,
        )
        return_dict["sdf_surf_grid"] = sdf_surf_grid
        return_dict["surf_grid"] = surf_grid

        # Store this only if normalization is active:
        if self.config.normalize_coordinates:
            return_dict["surface_min_max"] = torch.stack([s_min, s_max])

        # This is a center of mass computation for the stl surface,
        # using the size of each mesh point as weight.
        center_of_mass = calculate_center_of_mass(
            data_dict["stl_centers"], data_dict["stl_areas"]
        )

        # This will apply downsampling if needed to the geometry coordinates
        geom_centers = self.downsample_geometry(
            stl_vertices=data_dict["stl_coordinates"],
        )
        return_dict["geometry_coordinates"] = geom_centers

        ########################################################################
        # Determine the volumetric bounds of the data:
        ########################################################################
        # Compute the min/max for the volume and the unnormalized grid:
        c_min, c_max, volume_grid = self.compute_volume_scaling_and_grids()

        ########################################################################
        # Process the surface data
        ########################################################################
        if self.model_type == "surface" or self.model_type == "combined":
            # Ground truth is optional (absent at inference time).
            surface_fields_raw = (
                data_dict["surface_fields"] if "surface_fields" in data_dict else None
            )
            surface_dict = self.process_surface(
                s_min,
                s_max,
                c_min,
                c_max,
                center_of_mass=center_of_mass,
                surf_grid=surf_grid,
                surface_coordinates=data_dict["surface_mesh_centers"],
                surface_normals=data_dict["surface_normals"],
                surface_sizes=data_dict["surface_areas"],
                stl_vertices=data_dict["stl_coordinates"],
                stl_indices=mesh_indices_flattened,
                surface_fields=surface_fields_raw,
            )

            return_dict.update(surface_dict)

        ########################################################################
        # Process the volume data
        ########################################################################
        # For volume data, we store this only if normalizing coordinates:
        if self.model_type == "volume" or self.model_type == "combined":
            if self.config.normalize_coordinates:
                return_dict["volume_min_max"] = torch.stack([c_min, c_max])

        if self.model_type == "volume" or self.model_type == "combined":
            # Ground truth is optional (absent at inference time).
            volume_fields_raw = (
                data_dict["volume_fields"] if "volume_fields" in data_dict else None
            )
            volume_dict = self.process_volume(
                c_min,
                c_max,
                volume_coordinates=data_dict["volume_mesh_centers"],
                volume_grid=volume_grid,
                center_of_mass=center_of_mass,
                stl_vertices=data_dict["stl_coordinates"],
                stl_indices=mesh_indices_flattened,
                volume_fields=volume_fields_raw,
            )

            return_dict.update(volume_dict)

        # For domain parallelism, shard everything appropriately:
        if self.config.shard_grid or self.config.shard_points:
            # Mesh was defined above!
            output_dict = {}

            # For scattering, we need to know the _global_ index of rank
            # 0 on this mesh:
            global_index = dist.get_global_rank(mesh.get_group(), 0)

            for key, value in return_dict.items():
                # Grids shard along dim 0 only when grid sharding is on;
                # otherwise they are replicated.  Same logic for points.
                grid_placements = (
                    [
                        Shard(0),
                    ]
                    if self.config.shard_grid
                    else [
                        Replicate(),
                    ]
                )
                point_placements = (
                    [
                        Shard(0),
                    ]
                    if self.config.shard_points
                    else [
                        Replicate(),
                    ]
                )
                # The min/max summaries are tiny and always replicated.
                if key == "volume_min_max":
                    output_dict[key] = ShardTensor.from_local(
                        value,
                        mesh,
                        [
                            Replicate(),
                        ],
                    )
                elif key == "surface_min_max":
                    output_dict[key] = ShardTensor.from_local(
                        value,
                        mesh,
                        [
                            Replicate(),
                        ],
                    )
                elif not isinstance(value, ShardTensor):
                    if "grid" in key:
                        output_dict[key] = scatter_tensor(
                            value.contiguous(),
                            global_index,
                            mesh,
                            grid_placements,
                            global_shape=value.shape,
                            dtype=value.dtype,
                        )
                    else:
                        output_dict[key] = scatter_tensor(
                            value.contiguous(),
                            global_index,
                            mesh,
                            point_placements,
                            global_shape=value.shape,
                            dtype=value.dtype,
                        )
                else:
                    # Already a ShardTensor; pass through unchanged.
                    output_dict[key] = value

            return_dict = output_dict

        return return_dict
|
| 872 |
+
|
| 873 |
+
def scale_model_targets(
|
| 874 |
+
self, fields: torch.Tensor, factors: torch.Tensor
|
| 875 |
+
) -> torch.Tensor:
|
| 876 |
+
"""
|
| 877 |
+
Scale the model targets based on the configured scaling factors.
|
| 878 |
+
"""
|
| 879 |
+
if self.config.scaling_type == "mean_std_scaling":
|
| 880 |
+
field_mean = factors[0]
|
| 881 |
+
field_std = factors[1]
|
| 882 |
+
return standardize(fields, field_mean, field_std)
|
| 883 |
+
elif self.config.scaling_type == "min_max_scaling":
|
| 884 |
+
field_min = factors[1]
|
| 885 |
+
field_max = factors[0]
|
| 886 |
+
return normalize(fields, field_max, field_min)
|
| 887 |
+
|
| 888 |
+
    def unscale_model_outputs(
        self,
        volume_fields: torch.Tensor | None = None,
        surface_fields: torch.Tensor | None = None,
    ):
        """
        Unscale the model outputs based on the configured scaling factors.

        The unscaling is included here to make it a consistent interface regardless
        of the scaling factors and type used.

        ShardTensor inputs are converted to local tensors first, unscaled,
        and then re-wrapped with their original sharding spec, so the call
        is transparent to domain parallelism.

        Args:
            volume_fields: Optional model volume outputs (Tensor or ShardTensor).
            surface_fields: Optional model surface outputs (Tensor or ShardTensor).

        Returns:
            Tuple ``(volume_fields, surface_fields)`` with the inverse of the
            configured scaling applied; entries passed as None come back None.
        """

        # This is a step to make sure we can apply to sharded outputs:
        # stash the spec so we can re-wrap after unscaling.
        if volume_fields is not None and isinstance(volume_fields, ShardTensor):
            volume_spec = volume_fields._spec
            volume_fields = ShardTensor.to_local(volume_fields)
        else:
            volume_spec = None

        if surface_fields is not None and isinstance(surface_fields, ShardTensor):
            surface_spec = surface_fields._spec
            surface_fields = ShardTensor.to_local(surface_fields)
        else:
            surface_spec = None

        # Invert the scaling that scale_model_targets applied.
        # NOTE: factors follow the [max, min] / [mean, std] index convention.
        if volume_fields is not None:
            if self.config.scaling_type == "mean_std_scaling":
                vol_mean = self.config.volume_factors[0]
                vol_std = self.config.volume_factors[1]
                volume_fields = unstandardize(volume_fields, vol_mean, vol_std)
            elif self.config.scaling_type == "min_max_scaling":
                vol_min = self.config.volume_factors[1]
                vol_max = self.config.volume_factors[0]
                volume_fields = unnormalize(volume_fields, vol_max, vol_min)
        if surface_fields is not None:
            if self.config.scaling_type == "mean_std_scaling":
                surf_mean = self.config.surface_factors[0]
                surf_std = self.config.surface_factors[1]
                surface_fields = unstandardize(surface_fields, surf_mean, surf_std)
            elif self.config.scaling_type == "min_max_scaling":
                surf_min = self.config.surface_factors[1]
                surf_max = self.config.surface_factors[0]
                surface_fields = unnormalize(surface_fields, surf_max, surf_min)

        # Re-wrap with the original sharding layout, if the input was sharded.
        if volume_spec is not None:
            volume_fields = ShardTensor.from_local(
                volume_fields,
                device_mesh=volume_spec.mesh,
                placements=volume_spec.placements,
                sharding_shapes=volume_spec.sharding_shapes(),
            )
        if surface_spec is not None:
            surface_fields = ShardTensor.from_local(
                surface_fields,
                device_mesh=surface_spec.mesh,
                placements=surface_spec.placements,
                sharding_shapes=surface_spec.sharding_shapes(),
            )

        return volume_fields, surface_fields
|
| 949 |
+
|
| 950 |
+
    def set_dataset(self, dataset: Iterable) -> None:
        """
        Pass a dataset to the datapipe to enable iterating over both in one pass.
        """
        self.dataset = dataset

        if self.config.volume_sample_from_disk:
            # Deliberately over-pull from disk relative to the sampling size
            # so downstream filtering/sampling still has enough points.
            # NOTE(review): the factor here is 100x, though an earlier comment
            # described it as "double" — confirm the intended multiplier.
            self.dataset.set_volume_sampling_size(
                100 * self.config.volume_points_sample
            )
|
| 961 |
+
|
| 962 |
+
def __len__(self):
|
| 963 |
+
if self.dataset is not None:
|
| 964 |
+
return len(self.dataset)
|
| 965 |
+
else:
|
| 966 |
+
return 0
|
| 967 |
+
|
| 968 |
+
def __getitem__(self, idx):
|
| 969 |
+
"""
|
| 970 |
+
Function for fetching and processing a single file's data.
|
| 971 |
+
|
| 972 |
+
Domino, in general, expects one example per file and the files
|
| 973 |
+
are relatively large due to the mesh size.
|
| 974 |
+
|
| 975 |
+
Requires the user to have set a dataset via `set_dataset`.
|
| 976 |
+
"""
|
| 977 |
+
if self.dataset is None:
|
| 978 |
+
raise ValueError("Dataset is not present")
|
| 979 |
+
|
| 980 |
+
# Get the data from the dataset.
|
| 981 |
+
# Under the hood, this may be fetching preloaded data.
|
| 982 |
+
data_dict = self.dataset[idx]
|
| 983 |
+
|
| 984 |
+
return self.__call__(data_dict)
|
| 985 |
+
|
| 986 |
+
def __call__(self, data_dict: dict[str, torch.Tensor]) -> dict[str, torch.Tensor]:
|
| 987 |
+
"""
|
| 988 |
+
Process the incoming data dictionary.
|
| 989 |
+
- Processes the data
|
| 990 |
+
- moves it to GPU
|
| 991 |
+
- adds a batch dimension
|
| 992 |
+
|
| 993 |
+
Args:
|
| 994 |
+
data_dict: Dictionary containing the data to process as torch.Tensors.
|
| 995 |
+
|
| 996 |
+
Returns:
|
| 997 |
+
Dictionary containing the processed data as torch.Tensors.
|
| 998 |
+
|
| 999 |
+
"""
|
| 1000 |
+
data_dict = self.process_data(data_dict)
|
| 1001 |
+
|
| 1002 |
+
# If the data is not on the target device, put it there:
|
| 1003 |
+
for key, value in data_dict.items():
|
| 1004 |
+
if value.device != self.output_device:
|
| 1005 |
+
data_dict[key] = value.to(self.output_device)
|
| 1006 |
+
|
| 1007 |
+
# Add a batch dimension to the data_dict
|
| 1008 |
+
data_dict = {k: v.unsqueeze(0) for k, v in data_dict.items()}
|
| 1009 |
+
|
| 1010 |
+
return data_dict
|
| 1011 |
+
|
| 1012 |
+
def __iter__(self):
|
| 1013 |
+
if self.dataset is None:
|
| 1014 |
+
raise ValueError(
|
| 1015 |
+
"Dataset is not present, can not use the datapipe as an iterator."
|
| 1016 |
+
)
|
| 1017 |
+
|
| 1018 |
+
for i, batch in enumerate(self.dataset):
|
| 1019 |
+
yield self.__call__(batch)
|
| 1020 |
+
|
| 1021 |
+
|
| 1022 |
+
def compute_scaling_factors(
    cfg: DictConfig,
    input_path: str,
    target_keys: list[str],
    max_samples=20,
) -> tuple[torch.Tensor, torch.Tensor, torch.Tensor, torch.Tensor]:
    """
    Using the dataset at the path, compute the mean, std, min, and max of the target keys.

    Args:
        cfg: Hydra configuration object containing all parameters.
            NOTE(review): currently unused in this function — kept for
            interface compatibility; confirm whether it can be dropped.
        input_path: Path to the dataset to load.
        target_keys: List of keys to compute the mean, std, min, and max of.
        max_samples: Maximum number of samples drawn from the dataset while
            accumulating statistics.

    Returns:
        Tuple of (mean, std, min, max) statistics for the requested keys.
    """

    # Prefer GPU for the statistics pass when one is available.
    device = torch.device("cuda") if torch.cuda.is_available() else torch.device("cpu")

    dataset = CAEDataset(
        data_dir=input_path,
        keys_to_read=target_keys,
        keys_to_read_if_available={},
        output_device=device,
    )

    mean, std, min_val, max_val = compute_mean_std_min_max(
        dataset,
        field_keys=target_keys,
        max_samples=max_samples,
    )

    return mean, std, min_val, max_val
|
| 1054 |
+
|
| 1055 |
+
|
| 1056 |
+
class CachedDoMINODataset(Dataset):
|
| 1057 |
+
"""
|
| 1058 |
+
Dataset for reading cached DoMINO data files, with optional resampling.
|
| 1059 |
+
Acts as a drop-in replacement for DoMINODataPipe.
|
| 1060 |
+
"""
|
| 1061 |
+
|
| 1062 |
+
    # @nvtx_annotate(message="CachedDoMINODataset __init__")
    def __init__(
        self,
        data_path: Union[str, Path],
        phase: Literal["train", "val", "test"] = "train",
        sampling: bool = False,
        volume_points_sample: Optional[int] = None,
        surface_points_sample: Optional[int] = None,
        geom_points_sample: Optional[int] = None,
        model_type=None,  # Model_type, surface, volume or combined
        deterministic_seed=False,
        surface_sampling_algorithm="area_weighted",
    ):
        """
        Index the cached files under ``data_path`` and shuffle their order.

        Args:
            data_path: Directory containing the cached .npy sample files.
            phase: Dataset split label; stored on the instance.
            sampling: Whether to resample points when fetching an item.
            volume_points_sample: Number of volume points to sample per item.
            surface_points_sample: Number of surface points to sample per item.
            geom_points_sample: Number of geometry points to sample per item.
            model_type: "surface", "volume" or "combined".
            deterministic_seed: Seed numpy's global RNG for reproducibility.
            surface_sampling_algorithm: e.g. "area_weighted".

        Raises:
            AssertionError: if ``data_path`` does not exist, is not a
                directory, or contains no cached files.
        """
        super().__init__()

        self.model_type = model_type
        if deterministic_seed:
            # NOTE(review): this seeds numpy's *global* RNG, affecting all
            # other numpy-random users in the process — confirm intent.
            np.random.seed(42)

        if isinstance(data_path, str):
            data_path = Path(data_path)
        self.data_path = data_path.expanduser()

        # NOTE(review): AssertionError is unusual for input validation;
        # callers catching it depend on this type, so it is kept as-is.
        if not self.data_path.exists():
            raise AssertionError(f"Path {self.data_path} does not exist")
        if not self.data_path.is_dir():
            raise AssertionError(f"Path {self.data_path} is not a directory")

        self.deterministic_seed = deterministic_seed
        self.sampling = sampling
        self.volume_points = volume_points_sample
        self.surface_points = surface_points_sample
        self.geom_points = geom_points_sample
        self.surface_sampling_algorithm = surface_sampling_algorithm

        self.filenames = get_filenames(self.data_path, exclude_dirs=True)

        total_files = len(self.filenames)

        self.phase = phase
        # One index per cached file; shuffled once at construction so
        # sequential iteration visits files in random order.
        self.indices = np.array(range(total_files))

        np.random.shuffle(self.indices)

        if not self.filenames:
            raise AssertionError(f"No cached files found in {self.data_path}")
|
| 1108 |
+
|
| 1109 |
+
def __len__(self):
    """Number of cached samples indexed by this dataset."""
    return self.indices.shape[0]
|
| 1111 |
+
|
| 1112 |
+
# @nvtx_annotate(message="CachedDoMINODataset __getitem__")
def __getitem__(self, idx):
    """Load one cached sample and optionally subsample its point sets.

    Parameters
    ----------
    idx : int
        Position into the shuffled index array; mapped to a cached file.

    Returns
    -------
    dict
        The cached sample with numpy arrays converted to torch tensors.
        When ``self.sampling`` is True, volume/surface/geometry point sets
        are subsampled to the configured sizes (padded with sentinel values
        if the file holds fewer points) and ``neighbor_indices`` is consumed
        to build the surface-neighbor tensors.
    """
    # Per-item seed makes the random subsampling below reproducible.
    if self.deterministic_seed:
        np.random.seed(idx)
    nvtx.range_push("Load cached file")

    index = self.indices[idx]
    cfd_filename = self.filenames[index]

    filepath = self.data_path / cfd_filename
    # Cached files are pickled dicts of numpy arrays (hence allow_pickle).
    result = np.load(filepath, allow_pickle=True).item()
    result = {
        k: torch.from_numpy(v) if isinstance(v, np.ndarray) else v
        for k, v in result.items()
    }

    nvtx.range_pop()
    if not self.sampling:
        return result

    nvtx.range_push("Sample points")

    # Sample volume points if present
    if "volume_mesh_centers" in result and self.volume_points:
        # shuffle_array presumably returns (sampled rows, chosen indices);
        # the indices are reused to keep the per-point arrays aligned.
        coords_sampled, idx_volume = shuffle_array(
            result["volume_mesh_centers"], self.volume_points
        )
        # Pad with sentinel -10.0 so every sample has a fixed shape.
        if coords_sampled.shape[0] < self.volume_points:
            coords_sampled = pad(
                coords_sampled, self.volume_points, pad_value=-10.0
            )

        result["volume_mesh_centers"] = coords_sampled
        for key in [
            "volume_fields",
            "pos_volume_closest",
            "pos_volume_center_of_mass",
            "sdf_nodes",
        ]:
            if key in result:
                result[key] = result[key][idx_volume]

    # Sample surface points if present
    if "surface_mesh_centers" in result and self.surface_points:
        if self.surface_sampling_algorithm == "area_weighted":
            # Weight the draw by cell area so large faces are not
            # under-represented in the sampled subset.
            coords_sampled, idx_surface = shuffle_array(
                points=result["surface_mesh_centers"],
                n_points=self.surface_points,
                weights=result["surface_areas"],
            )
        else:
            coords_sampled, idx_surface = shuffle_array(
                result["surface_mesh_centers"], self.surface_points
            )

        if coords_sampled.shape[0] < self.surface_points:
            coords_sampled = pad(
                coords_sampled, self.surface_points, pad_value=-10.0
            )

        # NOTE: neighbor tensors are gathered from the FULL-resolution
        # arrays first; the centers are overwritten with the sampled,
        # padded subset only afterwards. Statement order matters here.
        ii = result["neighbor_indices"]
        result["surface_mesh_neighbors"] = result["surface_mesh_centers"][ii]
        result["surface_neighbors_normals"] = result["surface_normals"][ii]
        result["surface_neighbors_areas"] = result["surface_areas"][ii]

        result["surface_mesh_centers"] = coords_sampled

        for key in [
            "surface_fields",
            "surface_areas",
            "surface_normals",
            "pos_surface_center_of_mass",
            "surface_mesh_neighbors",
            "surface_neighbors_normals",
            "surface_neighbors_areas",
        ]:
            if key in result:
                result[key] = result[key][idx_surface]

        # Raw neighbor indices are no longer valid after subsampling.
        del result["neighbor_indices"]

    # Sample geometry points if present
    if "geometry_coordinates" in result and self.geom_points:
        coords_sampled, _ = shuffle_array(
            result["geometry_coordinates"], self.geom_points
        )
        # Geometry uses a different pad sentinel (-100.0) than the
        # volume/surface sets (-10.0) — presumably intentional; confirm.
        if coords_sampled.shape[0] < self.geom_points:
            coords_sampled = pad(coords_sampled, self.geom_points, pad_value=-100.0)
        result["geometry_coordinates"] = coords_sampled

    nvtx.range_pop()
    return result
|
| 1204 |
+
|
| 1205 |
+
|
| 1206 |
+
def create_domino_dataset(
    cfg: DictConfig,
    phase: Literal["train", "val", "test"],
    keys_to_read: list[str],
    keys_to_read_if_available: dict[str, torch.Tensor],
    vol_factors: list[float],
    surf_factors: list[float],
    normalize_coordinates: bool = True,
    sample_in_bbox: bool = True,
    sampling: bool = True,
    device_mesh: torch.distributed.DeviceMesh | None = None,
    placements: dict[str, torch.distributed.tensor.Placement] | None = None,
):
    """Construct the DoMINO dataset/datapipe for one phase.

    Parameters
    ----------
    cfg : DictConfig
        Hydra/OmegaConf config; reads ``cfg.model``, ``cfg.data``,
        ``cfg.data_processor`` and the per-phase sections.
    phase : Literal["train", "val", "test"]
        Split to build; selects the input directory and dataloader config.
    keys_to_read : list[str]
        Keys required from each raw sample (non-cached path only).
    keys_to_read_if_available : dict[str, torch.Tensor]
        Optional keys with fallback values (non-cached path only).
    vol_factors, surf_factors : list[float]
        Normalization factors for volume/surface fields.
    normalize_coordinates, sample_in_bbox, sampling : bool
        Preprocessing switches forwarded to the datapipe.
    device_mesh, placements
        Optional domain-parallel sharding configuration for CAEDataset.

    Returns
    -------
    ``CachedDoMINODataset`` when ``cfg.data_processor.use_cache`` is set,
    otherwise a ``DoMINODataPipe`` wrapping a ``CAEDataset``.

    Raises
    ------
    ValueError
        If ``phase`` is not one of "train", "val", "test".
    """
    model_type = cfg.model.model_type
    if phase == "train":
        input_path = cfg.data.input_dir
        dataloader_cfg = cfg.train.dataloader
    elif phase == "val":
        input_path = cfg.data.input_dir_val
        dataloader_cfg = cfg.val.dataloader
    elif phase == "test":
        input_path = cfg.eval.test_path
        dataloader_cfg = None
    else:
        raise ValueError(f"Invalid phase {phase}")

    if cfg.data_processor.use_cache:
        return CachedDoMINODataset(
            input_path,
            phase=phase,
            sampling=sampling,
            volume_points_sample=cfg.model.volume_points_sample,
            surface_points_sample=cfg.model.surface_points_sample,
            geom_points_sample=cfg.model.geom_points_sample,
            model_type=cfg.model.model_type,
            surface_sampling_algorithm=cfg.model.surface_sampling_algorithm,
        )
    else:
        # The dataset path works in two pieces:
        # - a core "dataset" (CAEDataset) that loads data and moves it to GPU;
        # - the preprocess step (DoMINODataPipe), configured here.
        # For backwards compatibility the preprocess object can accept a
        # dataset, which enables it as an iterator: it loops over the
        # dataset, preprocesses the output, and returns it.

        overrides = {}
        if hasattr(cfg.data, "gpu_preprocessing"):
            overrides["gpu_preprocessing"] = cfg.data.gpu_preprocessing

        if hasattr(cfg.data, "gpu_output"):
            overrides["gpu_output"] = cfg.data.gpu_output

        dm = DistributedManager()

        # BUGFIX: previously this read cfg.data.gpu_preprocessing
        # unconditionally, raising when the config omitted it even though the
        # hasattr() guard above treats the key as optional. Fall back to CPU
        # preprocessing when unspecified (conservative default — TODO confirm
        # against DoMINODataPipe's own default if it differs).
        if overrides.get("gpu_preprocessing", False):
            device = dm.device
            consumer_stream = torch.cuda.default_stream()
        else:
            device = torch.device("cpu")
            consumer_stream = None

        if dataloader_cfg is not None:
            preload_depth = dataloader_cfg.preload_depth
            pin_memory = dataloader_cfg.pin_memory
        else:
            # Test phase has no dataloader config; use minimal defaults.
            preload_depth = 1
            pin_memory = False

        dataset = CAEDataset(
            data_dir=input_path,
            keys_to_read=keys_to_read,
            keys_to_read_if_available=keys_to_read_if_available,
            output_device=device,
            preload_depth=preload_depth,
            pin_memory=pin_memory,
            device_mesh=device_mesh,
            placements=placements,
            consumer_stream=consumer_stream,
        )

        # Domain parallelism configuration:
        # (By default, the dataset will shard as aggressively as possible,
        # to improve IO speed and prevent bottlenecks - the datapipe
        # has to reshard to the final shape.)
        # NOTE: we can always capture the mesh and placements from the dataset
        # outputs, so no need to pass them here.
        if cfg.get("domain_parallelism", {}).get("domain_size", 1) > 1:
            shard_grid = cfg.get("domain_parallelism", {}).get("shard_grid", False)
            shard_points = cfg.get("domain_parallelism", {}).get("shard_points", False)
            overrides["shard_grid"] = shard_grid
            overrides["shard_points"] = shard_points

        datapipe = DoMINODataPipe(
            input_path,
            phase=phase,
            grid_resolution=cfg.model.interp_res,
            normalize_coordinates=normalize_coordinates,
            sampling=sampling,
            sample_in_bbox=sample_in_bbox,
            volume_points_sample=cfg.model.volume_points_sample,
            surface_points_sample=cfg.model.surface_points_sample,
            geom_points_sample=cfg.model.geom_points_sample,
            volume_factors=vol_factors,
            surface_factors=surf_factors,
            scaling_type=cfg.model.normalization,
            model_type=model_type,
            bounding_box_dims=cfg.data.bounding_box,
            bounding_box_dims_surf=cfg.data.bounding_box_surface,
            volume_sample_from_disk=cfg.data.volume_sample_from_disk,
            num_surface_neighbors=cfg.model.num_neighbors_surface,
            surface_sampling_algorithm=cfg.model.surface_sampling_algorithm,
            **overrides,
        )

        datapipe.set_dataset(dataset)

        return datapipe
|
| 1326 |
+
|
| 1327 |
+
|
| 1328 |
+
if __name__ == "__main__":
    # Smoke test: instantiate the datapipe against a local processed-data folder.
    pipe_kwargs = dict(
        data_path="/code/processed_data/new_models_1/",
        phase="train",
        sampling=False,
        sample_in_bbox=False,
    )
    fm_data = DoMINODataPipe(**pipe_kwargs)
|
physics_mcp/source/physicsnemo/datapipes/cae/mesh_datapipe.py
ADDED
|
@@ -0,0 +1,490 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
import vtk
|
| 21 |
+
|
| 22 |
+
try:
|
| 23 |
+
import nvidia.dali as dali
|
| 24 |
+
import nvidia.dali.plugin.pytorch as dali_pth
|
| 25 |
+
except ImportError:
|
| 26 |
+
raise ImportError(
|
| 27 |
+
"DALI dataset requires NVIDIA DALI package to be installed. "
|
| 28 |
+
+ "The package can be installed at:\n"
|
| 29 |
+
+ "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
|
| 30 |
+
)
|
| 31 |
+
|
| 32 |
+
from dataclasses import dataclass
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Iterable, List, Tuple, Union
|
| 35 |
+
|
| 36 |
+
from torch import Tensor
|
| 37 |
+
|
| 38 |
+
from physicsnemo.datapipes.datapipe import Datapipe
|
| 39 |
+
from physicsnemo.datapipes.meta import DatapipeMetaData
|
| 40 |
+
|
| 41 |
+
from .readers import read_cgns, read_vtp, read_vtu
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
@dataclass
class MetaData(DatapipeMetaData):
    """Capability metadata for MeshDatapipe, consumed by the datapipe framework."""

    name: str = "MeshDatapipe"
    # Optimization: pipeline supports automatic device placement and CUDA graphs.
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel: pipeline supports DDP sharding across ranks.
    ddp_sharding: bool = True
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class MeshDatapipe(Datapipe):
|
| 55 |
+
"""DALI data pipeline for mesh data
|
| 56 |
+
|
| 57 |
+
Parameters
|
| 58 |
+
----------
|
| 59 |
+
data_dir : str
|
| 60 |
+
Directory where ERA5 data is stored
|
| 61 |
+
variables : List[str, None]
|
| 62 |
+
Ordered list of variables to be loaded from the files
|
| 63 |
+
num_variables : int
|
| 64 |
+
Number of variables to be loaded from the files
|
| 65 |
+
file_format : str, optional
|
| 66 |
+
File format of the data, by default "vtp"
|
| 67 |
+
Supported formats: "vtp", "vtu", "cgns"
|
| 68 |
+
stats_dir : Union[str, None], optional
|
| 69 |
+
Directory where statistics are stored, by default None
|
| 70 |
+
If provided, the statistics are used to normalize the attributes
|
| 71 |
+
batch_size : int, optional
|
| 72 |
+
Batch size, by default 1
|
| 73 |
+
num_steps : int, optional
|
| 74 |
+
Number of timesteps are included in the output variables, by default 1
|
| 75 |
+
shuffle : bool, optional
|
| 76 |
+
Shuffle dataset, by default True
|
| 77 |
+
num_workers : int, optional
|
| 78 |
+
Number of workers, by default 1
|
| 79 |
+
device: Union[str, torch.device], optional
|
| 80 |
+
Device for DALI pipeline to run on, by default cuda
|
| 81 |
+
process_rank : int, optional
|
| 82 |
+
Rank ID of local process, by default 0
|
| 83 |
+
world_size : int, optional
|
| 84 |
+
Number of training processes, by default 1
|
| 85 |
+
cache_data : False, optional
|
| 86 |
+
Whether to cache the data in memory for faster access in subsequent epochs, by default False
|
| 87 |
+
Parallel: True, optional
|
| 88 |
+
Setting parallel=True for an external_source node indicates to the pipeline to run the source in Python worker processes started by DALI.
|
| 89 |
+
"""
|
| 90 |
+
|
| 91 |
+
def __init__(
|
| 92 |
+
self,
|
| 93 |
+
data_dir: str,
|
| 94 |
+
variables: List[str],
|
| 95 |
+
num_variables: int,
|
| 96 |
+
file_format: str = "vtp",
|
| 97 |
+
stats_dir: Union[str, None] = None,
|
| 98 |
+
batch_size: int = 1,
|
| 99 |
+
num_samples: int = 1,
|
| 100 |
+
shuffle: bool = True,
|
| 101 |
+
num_workers: int = 1,
|
| 102 |
+
device: Union[str, torch.device] = "cuda",
|
| 103 |
+
process_rank: int = 0,
|
| 104 |
+
world_size: int = 1,
|
| 105 |
+
cache_data: bool = False,
|
| 106 |
+
parallel: bool = True,
|
| 107 |
+
):
|
| 108 |
+
super().__init__(meta=MetaData())
|
| 109 |
+
self.file_format = file_format
|
| 110 |
+
self.variables = variables
|
| 111 |
+
self.num_variables = num_variables
|
| 112 |
+
self.batch_size = batch_size
|
| 113 |
+
self.num_workers = num_workers
|
| 114 |
+
self.shuffle = shuffle
|
| 115 |
+
self.data_dir = Path(data_dir)
|
| 116 |
+
self.stats_dir = Path(stats_dir) if stats_dir is not None else None
|
| 117 |
+
self.num_samples = num_samples
|
| 118 |
+
self.process_rank = process_rank
|
| 119 |
+
self.world_size = world_size
|
| 120 |
+
self.cache_data = cache_data
|
| 121 |
+
self.parallel = parallel
|
| 122 |
+
|
| 123 |
+
# if self.batch_size > 1:
|
| 124 |
+
# raise NotImplementedError("Batch size greater than 1 is not supported yet")
|
| 125 |
+
|
| 126 |
+
# Set up device, needed for pipeline
|
| 127 |
+
if isinstance(device, str):
|
| 128 |
+
device = torch.device(device)
|
| 129 |
+
# Need a index id if cuda
|
| 130 |
+
if device.type == "cuda" and device.index is None:
|
| 131 |
+
device = torch.device("cuda:0")
|
| 132 |
+
self.device = device
|
| 133 |
+
|
| 134 |
+
# check root directory exists
|
| 135 |
+
if not self.data_dir.is_dir():
|
| 136 |
+
raise IOError(f"Error, data directory {self.data_dir} does not exist")
|
| 137 |
+
|
| 138 |
+
self.parse_dataset_files()
|
| 139 |
+
self.load_statistics()
|
| 140 |
+
|
| 141 |
+
self.pipe = self._create_pipeline()
|
| 142 |
+
|
| 143 |
+
def parse_dataset_files(self) -> None:
|
| 144 |
+
"""Parses the data directory for valid files and determines training samples
|
| 145 |
+
|
| 146 |
+
Raises
|
| 147 |
+
------
|
| 148 |
+
ValueError
|
| 149 |
+
In channels specified or number of samples per year is not valid
|
| 150 |
+
"""
|
| 151 |
+
# get all input data files
|
| 152 |
+
match self.file_format:
|
| 153 |
+
case "vtp":
|
| 154 |
+
pattern = "*.vtp"
|
| 155 |
+
case "vtu":
|
| 156 |
+
pattern = "*.vtu"
|
| 157 |
+
case "cgns":
|
| 158 |
+
pattern = "*.cgns"
|
| 159 |
+
case _:
|
| 160 |
+
raise NotImplementedError(
|
| 161 |
+
f"Data type {self.file_format} is not supported yet"
|
| 162 |
+
)
|
| 163 |
+
|
| 164 |
+
self.data_paths = sorted(str(path) for path in self.data_dir.glob(pattern))
|
| 165 |
+
|
| 166 |
+
for data_path in self.data_paths:
|
| 167 |
+
self.logger.info(f"File found: {data_path}")
|
| 168 |
+
self.total_samples = len(self.data_paths)
|
| 169 |
+
|
| 170 |
+
if self.num_samples > self.total_samples:
|
| 171 |
+
raise ValueError(
|
| 172 |
+
"Number of requested samples is greater than the total number of available samples!"
|
| 173 |
+
)
|
| 174 |
+
self.logger.info(
|
| 175 |
+
f"Total number of samples: {self.total_samples}, number of requested samples: {self.num_samples}"
|
| 176 |
+
)
|
| 177 |
+
|
| 178 |
+
def load_statistics(
|
| 179 |
+
self,
|
| 180 |
+
) -> None: # TODO generalize and combine with climate/era5_hdf5 datapipes
|
| 181 |
+
"""Loads statistics from pre-computed numpy files
|
| 182 |
+
|
| 183 |
+
The statistic files should be of name global_means.npy and global_std.npy with
|
| 184 |
+
a shape of [1, C] located in the stat_dir.
|
| 185 |
+
|
| 186 |
+
Raises
|
| 187 |
+
------
|
| 188 |
+
IOError
|
| 189 |
+
If mean or std numpy files are not found
|
| 190 |
+
AssertionError
|
| 191 |
+
If loaded numpy arrays are not of correct size
|
| 192 |
+
"""
|
| 193 |
+
# If no stats dir we just skip loading the stats
|
| 194 |
+
if self.stats_dir is None:
|
| 195 |
+
self.mu = None
|
| 196 |
+
self.std = None
|
| 197 |
+
return
|
| 198 |
+
# load normalisation values
|
| 199 |
+
mean_stat_file = self.stats_dir / Path("global_means.npy")
|
| 200 |
+
std_stat_file = self.stats_dir / Path("global_stds.npy")
|
| 201 |
+
|
| 202 |
+
if not mean_stat_file.exists():
|
| 203 |
+
raise IOError(f"Mean statistics file {mean_stat_file} not found")
|
| 204 |
+
if not std_stat_file.exists():
|
| 205 |
+
raise IOError(f"Std statistics file {std_stat_file} not found")
|
| 206 |
+
|
| 207 |
+
# has shape [1, C]
|
| 208 |
+
self.mu = np.load(str(mean_stat_file))[:, 0 : self.num_variables]
|
| 209 |
+
# has shape [1, C]
|
| 210 |
+
self.sd = np.load(str(std_stat_file))[:, 0 : self.num_variables]
|
| 211 |
+
|
| 212 |
+
if not self.mu.shape == self.sd.shape == (1, self.num_variables):
|
| 213 |
+
raise AssertionError("Error, normalisation arrays have wrong shape")
|
| 214 |
+
|
| 215 |
+
def _create_pipeline(self) -> dali.Pipeline:
|
| 216 |
+
"""Create DALI pipeline
|
| 217 |
+
|
| 218 |
+
Returns
|
| 219 |
+
-------
|
| 220 |
+
dali.Pipeline
|
| 221 |
+
Mesh DALI pipeline
|
| 222 |
+
"""
|
| 223 |
+
pipe = dali.Pipeline(
|
| 224 |
+
batch_size=self.batch_size,
|
| 225 |
+
num_threads=2,
|
| 226 |
+
prefetch_queue_depth=2,
|
| 227 |
+
py_num_workers=self.num_workers,
|
| 228 |
+
device_id=self.device.index,
|
| 229 |
+
py_start_method="spawn",
|
| 230 |
+
)
|
| 231 |
+
|
| 232 |
+
with pipe:
|
| 233 |
+
source = MeshDaliExternalSource(
|
| 234 |
+
data_paths=self.data_paths,
|
| 235 |
+
file_format=self.file_format,
|
| 236 |
+
variables=self.variables,
|
| 237 |
+
num_samples=self.num_samples,
|
| 238 |
+
batch_size=self.batch_size,
|
| 239 |
+
shuffle=self.shuffle,
|
| 240 |
+
process_rank=self.process_rank,
|
| 241 |
+
world_size=self.world_size,
|
| 242 |
+
cache_data=self.cache_data,
|
| 243 |
+
)
|
| 244 |
+
# Update length of dataset
|
| 245 |
+
self.length = len(source) // self.batch_size
|
| 246 |
+
# Read current batch.
|
| 247 |
+
vertices, attributes, edges = dali.fn.external_source(
|
| 248 |
+
source,
|
| 249 |
+
num_outputs=3,
|
| 250 |
+
parallel=self.parallel,
|
| 251 |
+
batch=False,
|
| 252 |
+
)
|
| 253 |
+
|
| 254 |
+
if self.device.type == "cuda":
|
| 255 |
+
# Move tensors to GPU as external_source won't do that.
|
| 256 |
+
vertices = vertices.gpu()
|
| 257 |
+
attributes = attributes.gpu()
|
| 258 |
+
edges = edges.gpu()
|
| 259 |
+
|
| 260 |
+
# Normalize attributes if statistics are available.
|
| 261 |
+
if self.stats_dir is not None:
|
| 262 |
+
attributes = dali.fn.normalize(attributes, mean=self.mu, stddev=self.sd)
|
| 263 |
+
|
| 264 |
+
# Set outputs.
|
| 265 |
+
pipe.set_outputs(vertices, attributes, edges)
|
| 266 |
+
|
| 267 |
+
return pipe
|
| 268 |
+
|
| 269 |
+
def __iter__(self):
|
| 270 |
+
# Reset the pipeline before creating an iterator to enable epochs.
|
| 271 |
+
self.pipe.reset()
|
| 272 |
+
# Create DALI PyTorch iterator.
|
| 273 |
+
return dali_pth.DALIGenericIterator([self.pipe], ["vertices", "x", "edges"])
|
| 274 |
+
|
| 275 |
+
def __len__(self):
|
| 276 |
+
return self.length
|
| 277 |
+
|
| 278 |
+
|
| 279 |
+
class MeshDaliExternalSource:
    """DALI Source for lazy-loading with caching of mesh data

    Parameters
    ----------
    data_paths : Iterable[str]
        Paths of the mesh files to serve
    file_format : str
        Mesh file format: "vtp", "vtu", or "cgns"
    variables : List[str]
        Names of the point-data arrays to extract from each mesh
    num_samples : int
        Total number of training samples
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    cache_data : bool, optional
        Whether to cache the data in memory for faster access in subsequent
        epochs, by default False

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        file_format: str,
        variables: List[str],
        num_samples: int,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
        cache_data: bool = False,
    ):
        self.data_paths = list(data_paths)
        self.file_format = file_format
        self.variables = variables
        # Will be populated later once each worker starts running in its own process.
        self.poly_data = None
        self.num_samples = num_samples
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.cache_data = cache_data

        # Tracks the last epoch for which indices were shuffled.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

        # Resolve reader/parser callables once, based on the file format.
        self.mesh_reader_fn = self.mesh_reader()
        self.parse_vtk_data_fn = self.parse_vtk_data()

        if self.cache_data:
            # Make cache for the data: one slot per file, filled lazily.
            self.data_cache = {}
            for data_path in self.data_paths:
                self.data_cache[data_path] = None

    def __call__(
        self, sample_info: dali.types.SampleInfo
    ) -> Tuple[Tensor, Tensor, Tensor]:
        """Produce one (vertices, attributes, edges) sample for DALI."""
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        # Shuffle before the next epoch starts.
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers.
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index.
        idx = self.indices[sample_info.idx_in_epoch]

        # if self.poly_data is None:  # TODO check
        # This will be called once per worker. Workers are persistent,
        # so there is no need to explicitly close the files - this will be done
        # when corresponding pipeline/dataset is destroyed.
        if self.cache_data:
            # Lazy cache: parse on first access, reuse on later epochs.
            processed_data = self.data_cache.get(self.data_paths[idx])
            if processed_data is None:
                data = self.mesh_reader_fn(self.data_paths[idx])
                processed_data = self.parse_vtk_data_fn(data, self.variables)
                self.data_cache[self.data_paths[idx]] = processed_data
        else:
            data = self.mesh_reader_fn(self.data_paths[idx])
            processed_data = self.parse_vtk_data_fn(data, self.variables)

        return processed_data

    def __len__(self) -> int:
        # Number of samples assigned to this shard (not batches).
        return len(self.indices)

    def mesh_reader(self):
        """Return the file-reader callable matching ``self.file_format``."""
        if self.file_format == "vtp":
            return read_vtp
        if self.file_format == "vtu":
            return read_vtu
        if self.file_format == "cgns":
            return read_cgns
        else:
            raise NotImplementedError(
                f"Data type {self.file_format} is not supported yet"
            )

    def parse_vtk_data(self):
        """Return the parser callable matching ``self.file_format``."""
        if self.file_format == "vtp":
            return _parse_vtk_polydata
        elif self.file_format in ["vtu", "cgns"]:
            return _parse_vtk_unstructuredgrid
        else:
            raise NotImplementedError(
                f"Data type {self.file_format} is not supported yet"
            )
|
| 400 |
+
|
| 401 |
+
|
| 402 |
+
def _parse_vtk_polydata(polydata, variables):
    """Extract vertices, point attributes, and polygon edges from vtkPolyData.

    Parameters
    ----------
    polydata : vtk.vtkPolyData
        Surface mesh to parse.
    variables : list[str]
        Names of point-data arrays to stack column-wise into the attributes.

    Returns
    -------
    tuple
        ``(vertices [N, 3] float32, attributes [N, C] float32,
        edges [E, 2] int64)``.

    Raises
    ------
    ValueError
        If points, point data, polygons, or a requested array are missing.
    """
    # Fetch vertices
    points = polydata.GetPoints()
    if points is None:
        raise ValueError("Failed to get points from the polydata.")
    vertices = torch.tensor(
        np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]),
        dtype=torch.float32,
    )

    # Fetch node attributes # TODO modularize
    attributes = []
    point_data = polydata.GetPointData()
    if point_data is None:
        raise ValueError("Failed to get point data from the unstructured grid.")
    for array_name in variables:
        array = point_data.GetArray(array_name)
        # BUGFIX: GetArray returns None for a missing array name instead of
        # raising, so the previous try/except ValueError could never fire
        # and the code crashed with AttributeError below.
        if array is None:
            raise ValueError(
                f"Failed to get array {array_name} from the unstructured grid."
            )
        array_data = np.zeros(
            (points.GetNumberOfPoints(), array.GetNumberOfComponents())
        )
        for j in range(points.GetNumberOfPoints()):
            array.GetTuple(j, array_data[j])
        attributes.append(torch.tensor(array_data, dtype=torch.float32))
    if attributes:
        attributes = torch.cat(attributes, dim=-1)
    else:
        # Consistent with _parse_vtk_unstructuredgrid: torch.cat([]) raises.
        attributes = torch.zeros((1,), dtype=torch.float32)
    # TODO torch.cat is usually very inefficient when the number of items is large.
    # If possible, the resulting tensor should be pre-allocated and filled in during the loop.

    # Fetch edges
    polys = polydata.GetPolys()
    if polys is None:
        raise ValueError("Failed to get polygons from the polydata.")
    polys.InitTraversal()
    edges = []
    id_list = vtk.vtkIdList()
    for _ in range(polys.GetNumberOfCells()):
        polys.GetNextCell(id_list)
        num_ids = id_list.GetNumberOfIds()
        # BUGFIX: extend instead of reassign — the original rebound ``edges``
        # on every iteration, so only the LAST cell's edges survived.
        edges.extend(
            (id_list.GetId(j), id_list.GetId((j + 1) % num_ids))
            for j in range(num_ids)
        )
    if edges:
        edges = torch.tensor(edges, dtype=torch.long)
    else:
        # torch.tensor([]) would yield shape (0,), not (0, 2).
        edges = torch.zeros((0, 2), dtype=torch.long)

    return vertices, attributes, edges
|
| 450 |
+
|
| 451 |
+
|
| 452 |
+
def _parse_vtk_unstructuredgrid(grid, variables):
|
| 453 |
+
# Fetch vertices
|
| 454 |
+
points = grid.GetPoints()
|
| 455 |
+
if points is None:
|
| 456 |
+
raise ValueError("Failed to get points from the unstructured grid.")
|
| 457 |
+
vertices = torch.tensor(
|
| 458 |
+
np.array([points.GetPoint(i) for i in range(points.GetNumberOfPoints())]),
|
| 459 |
+
dtype=torch.float32,
|
| 460 |
+
)
|
| 461 |
+
|
| 462 |
+
# Fetch node attributes # TODO modularize
|
| 463 |
+
attributes = []
|
| 464 |
+
point_data = grid.GetPointData()
|
| 465 |
+
if point_data is None:
|
| 466 |
+
raise ValueError("Failed to get point data from the unstructured grid.")
|
| 467 |
+
for array_name in variables:
|
| 468 |
+
try:
|
| 469 |
+
array = point_data.GetArray(array_name)
|
| 470 |
+
except ValueError:
|
| 471 |
+
raise ValueError(
|
| 472 |
+
f"Failed to get array {array_name} from the unstructured grid."
|
| 473 |
+
)
|
| 474 |
+
array_data = np.zeros(
|
| 475 |
+
(points.GetNumberOfPoints(), array.GetNumberOfComponents())
|
| 476 |
+
)
|
| 477 |
+
for j in range(points.GetNumberOfPoints()):
|
| 478 |
+
array.GetTuple(j, array_data[j])
|
| 479 |
+
attributes.append(torch.tensor(array_data, dtype=torch.float32))
|
| 480 |
+
if variables:
|
| 481 |
+
attributes = torch.cat(attributes, dim=-1)
|
| 482 |
+
else:
|
| 483 |
+
attributes = torch.zeros((1,), dtype=torch.float32)
|
| 484 |
+
|
| 485 |
+
# Return a dummy tensor of zeros for edges since they are not directly computable
|
| 486 |
+
return (
|
| 487 |
+
vertices,
|
| 488 |
+
attributes,
|
| 489 |
+
torch.zeros((0, 2), dtype=torch.long),
|
| 490 |
+
) # Dummy tensor for edges
|
physics_mcp/source/physicsnemo/datapipes/cae/readers.py
ADDED
|
@@ -0,0 +1,191 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import os
|
| 18 |
+
from typing import Any
|
| 19 |
+
|
| 20 |
+
import torch
|
| 21 |
+
import vtk
|
| 22 |
+
|
| 23 |
+
Tensor = torch.Tensor
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
def read_vtp(file_path: str) -> Any:  # TODO add support for older format (VTK)
    """
    Read a VTP file and return the polydata.

    Parameters
    ----------
    file_path : str
        Path to the VTP file.

    Returns
    -------
    vtkPolyData
        The polydata read from the VTP file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.vtp`` extension or reading fails.
    """
    # Validate the path before handing it to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".vtp"):
        raise ValueError(f"Expected a .vtp file, got {file_path}")

    # Run the XML polydata reader over the file.
    vtp_reader = vtk.vtkXMLPolyDataReader()
    vtp_reader.SetFileName(file_path)
    vtp_reader.Update()

    # Bail out if the reader produced nothing usable.
    output = vtp_reader.GetOutput()
    if output is None:
        raise ValueError(f"Failed to read polydata from {file_path}")
    return output
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def read_vtu(file_path: str) -> Any:
    """
    Read a VTU file and return the unstructured grid data.

    Parameters
    ----------
    file_path : str
        Path to the VTU file.

    Returns
    -------
    vtkUnstructuredGrid
        The unstructured grid data read from the VTU file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.vtu`` extension or reading fails.
    """
    # Validate the path before handing it to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".vtu"):
        raise ValueError(f"Expected a .vtu file, got {file_path}")

    # Run the XML unstructured-grid reader over the file.
    vtu_reader = vtk.vtkXMLUnstructuredGridReader()
    vtu_reader.SetFileName(file_path)
    vtu_reader.Update()

    # Bail out if the reader produced nothing usable.
    output = vtu_reader.GetOutput()
    if output is None:
        raise ValueError(f"Failed to read unstructured grid data from {file_path}")
    return output
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
def read_cgns(file_path: str) -> Any:
    """
    Read a CGNS file and return the unstructured grid data.

    Parameters
    ----------
    file_path : str
        Path to the CGNS file.

    Returns
    -------
    vtkUnstructuredGrid
        The unstructured grid data read from the CGNS file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.cgns`` extension or reading fails.
    """
    # Validate the path before handing it to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".cgns"):
        raise ValueError(f"Expected a .cgns file, got {file_path}")

    # Run the CGNS reader over the file.
    cgns_reader = vtk.vtkCGNSReader()
    cgns_reader.SetFileName(file_path)
    cgns_reader.Update()

    # CGNS data arrives as a multi-block dataset.
    blocks = cgns_reader.GetOutput()
    if blocks is None:
        raise ValueError(f"Failed to read multi-block data from {file_path}")

    # Pull the vtkUnstructuredGrid out of the multi-block container.
    return _extract_unstructured_grid(blocks)
|
| 133 |
+
|
| 134 |
+
|
| 135 |
+
def read_stl(file_path: str) -> vtk.vtkPolyData:
    """
    Read an STL file and return the polydata.

    Parameters
    ----------
    file_path : str
        Path to the STL file.

    Returns
    -------
    vtkPolyData
        The polydata read from the STL file.

    Raises
    ------
    FileNotFoundError
        If ``file_path`` does not exist.
    ValueError
        If the file lacks a ``.stl`` extension or reading fails.
    """
    # Validate the path before handing it to VTK.
    if not os.path.exists(file_path):
        raise FileNotFoundError(f"{file_path} does not exist.")
    if not file_path.endswith(".stl"):
        raise ValueError(f"Expected a .stl file, got {file_path}")

    # Run the STL reader over the file.
    stl_reader = vtk.vtkSTLReader()
    stl_reader.SetFileName(file_path)
    stl_reader.Update()

    # Bail out if the reader produced nothing usable.
    output = stl_reader.GetOutput()
    if output is None:
        raise ValueError(f"Failed to read polydata from {file_path}")
    return output
|
| 170 |
+
|
| 171 |
+
|
| 172 |
+
def _extract_unstructured_grid(
    multi_block: vtk.vtkMultiBlockDataSet,
) -> vtk.vtkUnstructuredGrid:
    """
    Extracts a vtkUnstructuredGrid from a vtkMultiBlockDataSet.

    Only the first nested block (block 0 of block 0) is inspected; this
    mirrors the layout produced by vtkCGNSReader for single-zone files.

    Parameters
    ----------
    multi_block : vtk.vtkMultiBlockDataSet
        The multi-block dataset containing various data blocks.

    Returns
    -------
    vtk.vtkUnstructuredGrid
        The unstructured grid extracted from the multi-block dataset.

    Raises
    ------
    ValueError
        If the expected nested block is missing or is not a vtkUnstructuredGrid.
    """
    # BUGFIX: guard the first GetBlock(0) — chaining .GetBlock(0).GetBlock(0)
    # raised AttributeError (not the intended ValueError) on an empty dataset.
    outer = multi_block.GetBlock(0)
    if outer is None:
        raise ValueError("No vtkUnstructuredGrid found in the vtkMultiBlockDataSet.")
    block = outer.GetBlock(0)
    if isinstance(block, vtk.vtkUnstructuredGrid):
        return block
    raise ValueError("No vtkUnstructuredGrid found in the vtkMultiBlockDataSet.")
|
physics_mcp/source/physicsnemo/datapipes/climate/__init__.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from .climate import ClimateDatapipe, ClimateDataSourceSpec
|
| 18 |
+
from .era5_hdf5 import ERA5HDF5Datapipe
|
| 19 |
+
from .synthetic import SyntheticWeatherDataLoader, SyntheticWeatherDataset
|
physics_mcp/source/physicsnemo/datapipes/climate/climate.py
ADDED
|
@@ -0,0 +1,813 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import json
|
| 19 |
+
from abc import ABC, abstractmethod
|
| 20 |
+
from datetime import datetime, timedelta
|
| 21 |
+
from itertools import chain
|
| 22 |
+
|
| 23 |
+
import h5py
|
| 24 |
+
import netCDF4 as nc
|
| 25 |
+
import numpy as np
|
| 26 |
+
import pytz
|
| 27 |
+
import torch
|
| 28 |
+
|
| 29 |
+
try:
|
| 30 |
+
import nvidia.dali as dali
|
| 31 |
+
import nvidia.dali.plugin.pytorch as dali_pth
|
| 32 |
+
except ImportError:
|
| 33 |
+
raise ImportError(
|
| 34 |
+
"DALI dataset requires NVIDIA DALI package to be installed. "
|
| 35 |
+
+ "The package can be installed at:\n"
|
| 36 |
+
+ "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
from dataclasses import dataclass
|
| 40 |
+
from pathlib import Path
|
| 41 |
+
from typing import Callable, Iterable, List, Mapping, Tuple, Union
|
| 42 |
+
|
| 43 |
+
from scipy.io import netcdf_file
|
| 44 |
+
|
| 45 |
+
from physicsnemo.datapipes.climate.utils.invariant import latlon_grid
|
| 46 |
+
from physicsnemo.datapipes.climate.utils.zenith_angle import cos_zenith_angle
|
| 47 |
+
from physicsnemo.datapipes.datapipe import Datapipe
|
| 48 |
+
from physicsnemo.datapipes.meta import DatapipeMetaData
|
| 49 |
+
from physicsnemo.launch.logging import PythonLogger
|
| 50 |
+
|
| 51 |
+
Tensor = torch.Tensor
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
@dataclass
class MetaData(DatapipeMetaData):
    """Metadata flags for the climate datapipe.

    Field semantics are defined by the ``DatapipeMetaData`` base class; the
    values here declare what the climate pipeline supports.
    """

    name: str = "Climate"
    # Optimization
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel
    ddp_sharding: bool = True
|
| 62 |
+
|
| 63 |
+
|
| 64 |
+
class ClimateDataSourceSpec:
    """
    A data source specification for ClimateDatapipe.

    HDF5 files should contain the following variable with the corresponding
    name:
    `fields`: Tensor of shape (num_timesteps, num_channels, height, width),
    containing climate data. The order of the channels should match the order
    of the channels in the statistics files. The statistics files should be
    `.npy` files with the shape (1, num_channels, 1, 1).
    The names of the variables are found in the metadata file found in
    `metadata_path`.

    NetCDF4 files should contain a variable of shape
    (num_timesteps, height, width) for each variable they provide. Only the
    variables listed in `variables` will be loaded.

    Parameters
    ----------
    data_dir : str
        Directory where climate data is stored
    name: Union[str, None], optional
        The name that is used to label datapipe outputs from this source.
        If None, the datapipe uses the number of the source in sequential order.
    file_type: str
        Type of files to read, supported values are "hdf5" (default) and "netcdf4"
    stats_files: Union[Mapping[str, str], None], optional
        Numpy files containing data statistics for normalization. Supports either a
        channels format, in which case the dict should contain the keys "mean" and
        "std", or a named-variable format, in which case the dict should contain
        the key "norm". If None, no normalization will be used, by default None
    metadata_path: Union[str, None], optional for NetCDF, required for HDF5
        Path to the metadata JSON file for the dataset (usually called data.json).
    channels : Union[List[int], None], optional
        Defines which climate variables to load, if None will use all in HDF5 file, by default None
    variables: Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order specified
        by this parameter. Must be used for NetCDF4 files. Supported for HDF5 files
        in which case it will override `channels`.
    use_cos_zenith: bool, optional
        If True, the cosine zenith angles corresponding to the coordinates of this
        data source will be produced, default False
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any auxiliary
        variables returned from this source.
    num_steps : int, optional
        Number of timesteps to return, by default 1
    stride : int, optional
        Number of steps between input and output variables. For example, if the dataset
        contains data at every 6 hours, a stride 1 = 6 hour delta t and
        stride 2 = 12 hours delta t, by default 1
    backend_kwargs : Union[dict, None], optional
        Extra options stored for the reading backend. NOTE(review): only stored
        here, consumption not visible in this class — presumably used by the
        datapipe; confirm against callers. By default None (stored as {}).
    """

    def __init__(
        self,
        data_dir: str,
        name: Union[str, None] = None,
        file_type: str = "hdf5",
        stats_files: Union[Mapping[str, str], None] = None,
        metadata_path: Union[str, None] = None,
        channels: Union[List[int], None] = None,
        variables: Union[List[str], None] = None,
        use_cos_zenith: bool = False,
        aux_variables: Union[Mapping[str, Callable], None] = None,
        num_steps: int = 1,
        stride: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_dir = Path(data_dir)
        self.name = name
        self.file_type = file_type
        # Convert stats file names to Paths, preserving None to mean "no stats".
        self.stats_files = (
            {k: Path(fn) for (k, fn) in stats_files.items()}
            if stats_files is not None
            else None
        )
        self.metadata_path = Path(metadata_path) if metadata_path is not None else None
        self.channels = channels
        self.variables = variables
        self.use_cos_zenith = use_cos_zenith
        self.aux_variables = aux_variables if aux_variables is not None else {}
        self.num_steps = num_steps
        self.stride = stride
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs
        self.logger = PythonLogger()

        # NetCDF4 files are read by named variable, so `variables` is mandatory.
        if file_type == "netcdf4" and not variables:
            raise ValueError("Variables must be specified for a NetCDF4 source.")

        # check root directory exists
        if not self.data_dir.is_dir():
            raise IOError(f"Error, data directory {self.data_dir} does not exist")
        if self.stats_files is None:
            self.logger.warning(
                "Warning, no stats files specified, this will result in no normalisation"
            )

    def dimensions_compatible(self, other) -> bool:
        """
        Basic sanity check to test if two `ClimateDataSourceSpec` are
        compatible.

        Both specs must have had `parse_dataset_files` called first, since
        every attribute compared here is set by that method.

        Parameters
        ----------
        other : ClimateDataSourceSpec
            The spec to compare against.

        Returns
        -------
        bool
            True if the dataset dimensions of both specs match.
        """
        return (
            self.data_shape == other.data_shape
            and self.cropped_data_shape == other.cropped_data_shape
            and self.num_samples_per_year == other.num_samples_per_year
            and self.total_length == other.total_length
            and self.n_years == other.n_years
        )

    def parse_dataset_files(
        self,
        num_samples_per_year: Union[int, None] = None,
        # NOTE(review): annotation was Union[int, None], but the code below
        # subscripts patch_size[i], so a bare int is not actually supported
        # despite the docstring — confirm intended type.
        patch_size: Union[Tuple[int, int], None] = None,
    ) -> None:
        """Parses the data directory for valid files and determines training samples

        Parameters
        ----------
        num_samples_per_year : int, optional
            Number of samples taken from each year. If None, all will be used, by default None
        patch_size : Union[Tuple[int, int], None], optional
            If specified, crops input and output variables so image dimensions are
            divisible by patch_size, by default None

        Raises
        ------
        ValueError
            If channels specified or number of samples per year is not valid
        """
        # get all input data files
        suffix = {"hdf5": "h5", "netcdf4": "nc"}[self.file_type]
        self.data_paths = sorted(self.data_dir.glob(f"*.{suffix}"))
        for data_path in self.data_paths:
            self.logger.info(f"Climate data file found: {data_path}")
        self.n_years = len(self.data_paths)
        self.logger.info(f"Number of years: {self.n_years}")

        # get total number of examples and image shape from the first file,
        # assuming other files have exactly the same format.
        # NOTE(review): raises IndexError if the directory holds no matching
        # files — a clearer error message might be preferable.
        self.logger.info(f"Getting file stats from {self.data_paths[0]}")
        if self.file_type == "hdf5":
            with h5py.File(self.data_paths[0], "r") as f:
                dataset_shape = f["fields"].shape
        else:
            # NetCDF: synthesize an HDF5-like (time, channel, h, w) shape from
            # the first named variable plus the variable count.
            with nc.Dataset(self.data_paths[0], "r") as f:
                var_shape = f[self.variables[0]].shape
                dataset_shape = (var_shape[0], len(self.variables)) + var_shape[1:]

        # truncate the dataset to avoid out-of-range sampling
        data_samples_per_year = dataset_shape[0] - (self.num_steps - 1) * self.stride
        self.data_shape = dataset_shape[2:]

        # interpret list of variables into list of channels or vice versa
        if self.file_type == "hdf5":
            with open(self.metadata_path, "r") as f:
                metadata = json.load(f)
                data_vars = metadata["coords"]["channel"]
            if self.variables is not None:
                # variables take precedence: derive channel indices from names
                self.channels = [data_vars.index(v) for v in self.variables]
            else:
                if self.channels is None:
                    self.variables = data_vars
                else:
                    self.variables = [data_vars[i] for i in self.channels]

        # If channels not provided, use all of them
        if self.channels is None:
            self.channels = list(range(dataset_shape[1]))

        # If num_samples_per_year use all
        if num_samples_per_year is None:
            num_samples_per_year = data_samples_per_year
        self.num_samples_per_year = num_samples_per_year

        # Adjust image shape if patch_size defined
        if patch_size is not None:
            self.cropped_data_shape = tuple(
                s - s % patch_size[i] for i, s in enumerate(self.data_shape)
            )
        else:
            self.cropped_data_shape = self.data_shape
        self.logger.info(f"Input data shape: {self.cropped_data_shape}")

        # Get total length
        self.total_length = self.n_years * self.num_samples_per_year

        # Sanity checks
        if max(self.channels) >= dataset_shape[1]:
            raise ValueError(
                f"Provided channel has indexes greater than the number \
                of fields {dataset_shape[1]}"
            )

        if self.num_samples_per_year > data_samples_per_year:
            raise ValueError(
                f"num_samples_per_year ({self.num_samples_per_year}) > number of \
                samples available ({data_samples_per_year})!"
            )

        self._load_statistics()

        self.logger.info(f"Number of samples/year: {self.num_samples_per_year}")
        self.logger.info(f"Number of channels available: {dataset_shape[1]}")

    def _load_statistics(self) -> None:
        """Loads climate statistics from pre-computed numpy files

        The statistic files should be of name global_means.npy and global_std.npy with
        a shape of [1, C, 1, 1] located in the stat_dir.

        Raises
        ------
        IOError
            If statistics files are not found
        ValueError
            If the statistics file specification is invalid or the loaded
            arrays do not have the expected shape
        """
        # If no stats files we just skip loading the stats
        if self.stats_files is None:
            self.mu = None
            self.sd = None
            return
        # load normalisation values
        if set(self.stats_files) == {"mean", "std"}:  # use mean and std files
            mean_stat_file = self.stats_files["mean"]
            std_stat_file = self.stats_files["std"]

            if not mean_stat_file.exists():
                raise IOError(f"Mean statistics file {mean_stat_file} not found")
            if not std_stat_file.exists():
                raise IOError(f"Std statistics file {std_stat_file} not found")

            # has shape [1, C, 1, 1]
            self.mu = np.load(str(mean_stat_file))[:, self.channels]
            # has shape [1, C, 1, 1]
            self.sd = np.load(str(std_stat_file))[:, self.channels]
        elif set(self.stats_files) == {
            "norm",
        }:  # use dict formatted file with named variables
            norm_stat_file = self.stats_files["norm"]
            if not norm_stat_file.exists():
                raise IOError(f"Statistics file {norm_stat_file} not found")

            # allow_pickle is needed to load the dict-of-stats file; assumes
            # the stats file is trusted (pickle can run arbitrary code).
            norm = np.load(str(norm_stat_file), allow_pickle=True).item()
            mu = np.array([norm[var]["mean"] for var in self.variables])
            self.mu = mu.reshape((1, len(mu), 1, 1))
            sd = np.array([norm[var]["std"] for var in self.variables])
            self.sd = sd.reshape((1, len(sd), 1, 1))
        else:
            raise ValueError(("Invalid statistics file specification"))

        if not self.mu.shape == self.sd.shape == (1, len(self.channels), 1, 1):
            raise ValueError("Error, normalisation arrays have wrong shape")
|
| 319 |
+
|
| 320 |
+
|
| 321 |
+
class ClimateDatapipe(Datapipe):
|
| 322 |
+
"""
|
| 323 |
+
A Climate DALI data pipeline. This pipeline loads data from
|
| 324 |
+
HDF5/NetCDF4 files. It can also return additional data such as the
|
| 325 |
+
solar zenith angle for each time step. Additionally, it normalizes
|
| 326 |
+
the data if a statistics file is provided. The pipeline returns a dictionary
|
| 327 |
+
with the following structure, where {name} indicates the name of the data
|
| 328 |
+
source provided:
|
| 329 |
+
|
| 330 |
+
- ``state_seq-{name}``: Tensors of shape
|
| 331 |
+
(batch_size, num_steps, num_channels, height, width).
|
| 332 |
+
This sequence is drawn from the data file and normalized if a
|
| 333 |
+
statistics file is provided.
|
| 334 |
+
- ``timestamps-{name}``: Tensors of shape (batch_size, num_steps), containing
|
| 335 |
+
timestamps for each timestep in the sequence.
|
| 336 |
+
- ``{aux_variable}-{name}``: Tensors of shape
|
| 337 |
+
(batch_size, num_steps, aux_channels, height, width),
|
| 338 |
+
containing the auxiliary variables returned by each data source
|
| 339 |
+
- ``cos_zenith-{name}``: Tensors of shape (batch_size, num_steps, 1, height, width),
|
| 340 |
+
containing the cosine of the solar zenith angle if specified.
|
| 341 |
+
- ``{invariant_name}``: Tensors of shape (batch_size, invariant_channels, height, width),
|
| 342 |
+
containing the time-invariant data (depending only on spatial coordinates)
|
| 343 |
+
returned by the datapipe. These can include e.g.
|
| 344 |
+
land-sea mask and geopotential/surface elevation.
|
| 345 |
+
|
| 346 |
+
To use this data pipeline, your data directory must be structured as
|
| 347 |
+
follows:
|
| 348 |
+
```
|
| 349 |
+
data_dir
|
| 350 |
+
├── 1980.h5
|
| 351 |
+
├── 1981.h5
|
| 352 |
+
├── 1982.h5
|
| 353 |
+
├── ...
|
| 354 |
+
└── 2020.h5
|
| 355 |
+
```
|
| 356 |
+
|
| 357 |
+
The files are assumed have no metadata, such as timestamps.
|
| 358 |
+
Because of this, it's important to specify the `dt` parameter and the
|
| 359 |
+
`start_year` parameter so that the pipeline can compute the correct
|
| 360 |
+
timestamps for each timestep. These timestamps are then used to compute the
|
| 361 |
+
cosine of the solar zenith angle, if specified.
|
| 362 |
+
|
| 363 |
+
Parameters
|
| 364 |
+
----------
|
| 365 |
+
sources: Iterable[ClimateDataSourceSpec]
|
| 366 |
+
A list of data specifications defining the sources for the climate variables
|
| 367 |
+
batch_size : int, optional
|
| 368 |
+
Batch size, by default 1
|
| 369 |
+
dt : float, optional
|
| 370 |
+
Time in hours between each timestep in the dataset, by default 6 hr
|
| 371 |
+
start_year : int, optional
|
| 372 |
+
Start year of dataset, by default 1980
|
| 373 |
+
latlon_bounds : Tuple[Tuple[float, float], Tuple[float, float]], optional
|
| 374 |
+
Bounds of latitude and longitude in the data, in the format
|
| 375 |
+
((lat_start, lat_end,), (lon_start, lon_end)).
|
| 376 |
+
By default ((90, -90), (0, 360)).
|
| 377 |
+
crop_window: Union[Tuple[Tuple[float, float], Tuple[float, float]], None], optional
|
| 378 |
+
The window to crop the data to, in the format ((i0,i1), (j0,j1)) where the
|
| 379 |
+
first spatial dimension will be cropped to i0:i1 and the second to j0:j1.
|
| 380 |
+
If not given, all data will be used.
|
| 381 |
+
invariants : Mapping[str,Callable], optional
|
| 382 |
+
Specifies the time-invariant data (for example latitude and longitude)
|
| 383 |
+
included in the data samples. Should be a dict where the keys are the
|
| 384 |
+
names of the invariants and the values are the corresponding
|
| 385 |
+
functions. The functions need to accept an argument of the shape
|
| 386 |
+
(2, data_shape[0], data_shape[1]) where the first dimension contains
|
| 387 |
+
latitude and longitude in degrees and the other dimensions corresponding
|
| 388 |
+
to the shape of data in the data files. For example,
|
| 389 |
+
invariants={"trig_latlon": invariants.LatLon()}
|
| 390 |
+
will include the sin/cos of lat/lon in the output.
|
| 391 |
+
num_samples_per_year : int, optional
|
| 392 |
+
Number of samples taken from each year. If None, all will be used, by default None
|
| 393 |
+
shuffle : bool, optional
|
| 394 |
+
Shuffle dataset, by default True
|
| 395 |
+
num_workers : int, optional
|
| 396 |
+
Number of workers, by default 1
|
| 397 |
+
device: Union[str, torch.device], optional
|
| 398 |
+
Device for DALI pipeline to run on, by default cuda
|
| 399 |
+
process_rank : int, optional
|
| 400 |
+
Rank ID of local process, by default 0
|
| 401 |
+
world_size : int, optional
|
| 402 |
+
Number of training processes, by default 1
|
| 403 |
+
"""
|
| 404 |
+
|
| 405 |
+
def __init__(
|
| 406 |
+
self,
|
| 407 |
+
sources: Iterable[ClimateDataSourceSpec],
|
| 408 |
+
batch_size: int = 1,
|
| 409 |
+
dt: float = 6.0,
|
| 410 |
+
start_year: int = 1980,
|
| 411 |
+
latlon_bounds: Tuple[Tuple[float, float], Tuple[float, float]] = (
|
| 412 |
+
(90, -90),
|
| 413 |
+
(0, 360),
|
| 414 |
+
),
|
| 415 |
+
crop_window: Union[
|
| 416 |
+
Tuple[Tuple[float, float], Tuple[float, float]], None
|
| 417 |
+
] = None,
|
| 418 |
+
invariants: Union[Mapping[str, Callable], None] = None,
|
| 419 |
+
num_samples_per_year: Union[int, None] = None,
|
| 420 |
+
shuffle: bool = True,
|
| 421 |
+
num_workers: int = 1, # TODO: is there a faster good default?
|
| 422 |
+
device: Union[str, torch.device] = "cuda",
|
| 423 |
+
process_rank: int = 0,
|
| 424 |
+
world_size: int = 1,
|
| 425 |
+
):
|
| 426 |
+
super().__init__(meta=MetaData())
|
| 427 |
+
self.sources = list(sources)
|
| 428 |
+
self.batch_size = batch_size
|
| 429 |
+
self.num_workers = num_workers
|
| 430 |
+
self.shuffle = shuffle
|
| 431 |
+
self.dt = dt
|
| 432 |
+
self.start_year = start_year
|
| 433 |
+
self.data_latlon_bounds = latlon_bounds
|
| 434 |
+
self.process_rank = process_rank
|
| 435 |
+
self.world_size = world_size
|
| 436 |
+
self.num_samples_per_year = num_samples_per_year
|
| 437 |
+
self.logger = PythonLogger()
|
| 438 |
+
|
| 439 |
+
if invariants is None:
|
| 440 |
+
invariants = {}
|
| 441 |
+
|
| 442 |
+
# Determine outputs of pipeline
|
| 443 |
+
self.pipe_outputs = []
|
| 444 |
+
for i, spec in enumerate(self.sources):
|
| 445 |
+
name = spec.name if spec.name is not None else i
|
| 446 |
+
self.pipe_outputs += [f"state_seq-{name}", f"timestamps-{name}"]
|
| 447 |
+
self.pipe_outputs.extend(
|
| 448 |
+
f"{aux_var}-{name}" for aux_var in spec.aux_variables
|
| 449 |
+
)
|
| 450 |
+
if spec.use_cos_zenith:
|
| 451 |
+
self.pipe_outputs.append(f"cos_zenith-{name}")
|
| 452 |
+
self.pipe_outputs.extend(invariants.keys())
|
| 453 |
+
|
| 454 |
+
# Set up device, needed for pipeline
|
| 455 |
+
if isinstance(device, str):
|
| 456 |
+
device = torch.device(device)
|
| 457 |
+
|
| 458 |
+
# Need a index id if cuda
|
| 459 |
+
if device.type == "cuda" and device.index is None:
|
| 460 |
+
device = torch.device("cuda:0")
|
| 461 |
+
self.device = device
|
| 462 |
+
|
| 463 |
+
# Load all data files and statistics
|
| 464 |
+
for spec in sources:
|
| 465 |
+
spec.parse_dataset_files(num_samples_per_year=num_samples_per_year)
|
| 466 |
+
for i, spec_i in enumerate(sources):
|
| 467 |
+
for spec_j in sources[i + 1 :]:
|
| 468 |
+
if not spec_i.dimensions_compatible(spec_j):
|
| 469 |
+
raise ValueError("Incompatible data sources")
|
| 470 |
+
|
| 471 |
+
self.data_latlon = np.stack(
|
| 472 |
+
latlon_grid(bounds=self.data_latlon_bounds, shape=sources[0].data_shape),
|
| 473 |
+
axis=0,
|
| 474 |
+
)
|
| 475 |
+
if crop_window is None:
|
| 476 |
+
crop_window = (
|
| 477 |
+
(0, sources[0].cropped_data_shape[0]),
|
| 478 |
+
(0, sources[0].cropped_data_shape[1]),
|
| 479 |
+
)
|
| 480 |
+
self.crop_window = crop_window
|
| 481 |
+
self.window_latlon = self._crop_to_window(self.data_latlon)
|
| 482 |
+
self.window_latlon_dali = dali.types.Constant(self.window_latlon)
|
| 483 |
+
|
| 484 |
+
# load invariants
|
| 485 |
+
self.invariants = {
|
| 486 |
+
var: callback(self.window_latlon) for (var, callback) in invariants.items()
|
| 487 |
+
}
|
| 488 |
+
|
| 489 |
+
# Create pipeline
|
| 490 |
+
self.pipe = self._create_pipeline()
|
| 491 |
+
|
| 492 |
+
def _source_cls_from_type(self, source_type: str) -> type:
|
| 493 |
+
"""Get the external source class based on a string descriptor."""
|
| 494 |
+
return {
|
| 495 |
+
"hdf5": ClimateHDF5DaliExternalSource,
|
| 496 |
+
"netcdf4": ClimateNetCDF4DaliExternalSource,
|
| 497 |
+
}[source_type]
|
| 498 |
+
|
| 499 |
+
def _crop_to_window(self, x):
|
| 500 |
+
cw = self.crop_window
|
| 501 |
+
if isinstance(x, dali.pipeline.DataNode):
|
| 502 |
+
# DALI doesn't support ellipsis notation
|
| 503 |
+
return x[:, :, cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]
|
| 504 |
+
else:
|
| 505 |
+
return x[..., cw[0][0] : cw[0][1], cw[1][0] : cw[1][1]]
|
| 506 |
+
|
| 507 |
+
    def _source_outputs(self, spec: ClimateDataSourceSpec) -> List:
        """Create DALI outputs for a given data source specification.

        Builds the external source for `spec`, wires it into the pipeline
        graph, crops and (optionally) normalizes the state sequence, and
        appends the cosine zenith angle when requested.

        Parameters
        ----------
        spec: ClimateDataSourceSpec
            The data source specification.
        """
        # HDF5/NetCDF source
        source_cls = self._source_cls_from_type(spec.file_type)
        source = source_cls(
            data_paths=spec.data_paths,
            num_samples=spec.total_length,
            channels=spec.channels,
            latlon=self.data_latlon,
            variables=spec.variables,
            aux_variables=spec.aux_variables,
            stride=spec.stride,
            dt=self.dt,
            start_year=self.start_year,
            num_steps=spec.num_steps,
            num_samples_per_year=spec.num_samples_per_year,
            batch_size=self.batch_size,
            shuffle=self.shuffle,
            process_rank=self.process_rank,
            world_size=self.world_size,
        )

        # Update length of dataset
        # NOTE(review): this is overwritten for every source, so the final
        # pipeline length is that of the last source processed — presumably
        # all sources are expected to have equal length; confirm.
        self.total_length = len(source) // self.batch_size

        # Read current batch: state sequence, timestamps, then any aux vars,
        # matching the tuple returned by ClimateDaliExternalSource.__call__.
        (state_seq, timestamps, *aux) = dali.fn.external_source(
            source,
            num_outputs=source.num_outputs(),
            parallel=True,
            batch=False,
        )

        # Crop
        state_seq = self._crop_to_window(state_seq)
        aux = (self._crop_to_window(x) for x in aux)

        # Normalize only when statistics were provided for this source
        if spec.stats_files is not None:
            state_seq = dali.fn.normalize(state_seq, mean=spec.mu, stddev=spec.sd)

        # Make output list
        outputs = [state_seq, timestamps, *aux]

        # Get cosine zenith angle on the cropped lat/lon grid
        if spec.use_cos_zenith:
            cos_zenith = dali.fn.cast(
                cos_zenith_angle(timestamps, latlon=self.window_latlon_dali),
                dtype=dali.types.FLOAT,
            )
            outputs.append(cos_zenith)

        return outputs
|
| 566 |
+
|
| 567 |
+
def _invariant_outputs(self):
|
| 568 |
+
for inv in self.invariants.values():
|
| 569 |
+
if self.crop_window is not None:
|
| 570 |
+
inv = self._crop_to_window(inv)
|
| 571 |
+
yield dali.types.Constant(inv)
|
| 572 |
+
|
| 573 |
+
    def _create_pipeline(self) -> dali.Pipeline:
        """Create DALI pipeline

        Builds the pipeline graph from all data sources plus the invariant
        constants, and moves outputs to GPU when running on a CUDA device.

        Returns
        -------
        dali.Pipeline
            Climate DALI pipeline
        """
        pipe = dali.Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            prefetch_queue_depth=2,
            py_num_workers=self.num_workers,
            device_id=self.device.index,
            # NOTE(review): "spawn" start method — presumably required for the
            # parallel external_source workers; confirm against DALI docs.
            py_start_method="spawn",
        )

        with pipe:
            # Concatenate outputs from all sources as well as invariants;
            # order matches self.pipe_outputs built in __init__.
            outputs = list(
                chain(
                    *(self._source_outputs(spec) for spec in self.sources),
                    self._invariant_outputs(),
                )
            )

            if self.device.type == "cuda":
                # Move tensors to GPU as external_source won't do that
                outputs = [o.gpu() for o in outputs]

            # Set outputs
            pipe.set_outputs(*outputs)

        return pipe
|
| 607 |
+
|
| 608 |
+
    def __iter__(self):
        """Return a fresh DALI-PyTorch iterator over the pipeline.

        Batches are keyed by the names in ``self.pipe_outputs``.
        """
        # Reset the pipeline before creating an iterator to enable epochs.
        self.pipe.reset()
        # Create DALI PyTorch iterator.
        return dali_pth.DALIGenericIterator([self.pipe], self.pipe_outputs)
|
| 613 |
+
|
| 614 |
+
    def __len__(self):
        """Number of full batches per epoch (set while building source outputs)."""
        return self.total_length
|
| 616 |
+
|
| 617 |
+
|
| 618 |
+
class ClimateDaliExternalSource(ABC):
    """DALI Source for lazy-loading the HDF5/NetCDF4 climate files

    Parameters
    ----------
    data_paths : Iterable[str]
        Directory where climate data is stored
    num_samples : int
        Total number of training samples
    channels : Iterable[int]
        List representing which climate variables to load
    num_steps : int
        Number of timesteps to load
    stride : int
        Number of steps between input and output variables
    dt : float, optional
        Time in hours between each timestep in the dataset, by default 6 hr
    start_year : int, optional
        Start year of dataset, by default 1980
    num_samples_per_year : int
        Number of samples randomly taken from each year
    latlon : np.ndarray
        Array of shape (2, H, W) with latitude and longitude in degrees,
        passed through to the auxiliary variable callables.
    variables: Union[List[str], None], optional for HDF5 files, mandatory for NetCDF4 files
        List of named variables to load. Variables will be read in the order specified
        by this parameter.
    aux_variables : Union[Mapping[str, Callable], None], optional
        A dictionary mapping strings to callables that accept arguments
        (timestamps: numpy.ndarray, latlon: numpy.ndarray). These define any auxiliary
        variables returned from this source. By default None (no aux variables).
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1
    backend_kwargs : Union[dict, None], optional
        Extra options for the file-reading backend (see subclasses),
        by default None.

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        num_samples: int,
        channels: Iterable[int],
        num_steps: int,
        stride: int,
        dt: float,
        start_year: int,
        num_samples_per_year: int,
        latlon: np.ndarray,
        variables: Union[List[str], None] = None,
        # Fixed: previously annotated List[Union[str, Callable]] with a ()
        # default, but the class consumes this as a mapping (.values(), len())
        # in __call__/num_outputs, so the old default crashed at runtime.
        aux_variables: Union[Mapping[str, Callable], None] = None,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
        backend_kwargs: Union[dict, None] = None,
    ):
        self.data_paths = list(data_paths)
        # Will be populated later once each worker starts running in its own process.
        self.data_files = [None] * len(self.data_paths)
        self.num_samples = num_samples
        self.chans = list(channels)
        self.latlon = latlon
        self.variables = variables
        self.aux_variables = {} if aux_variables is None else aux_variables
        self.num_steps = num_steps
        self.stride = stride
        self.dt = dt
        self.start_year = start_year
        self.num_samples_per_year = num_samples_per_year
        self.batch_size = batch_size
        self.shuffle = shuffle
        self.backend_kwargs = {} if backend_kwargs is None else backend_kwargs

        # Tracks the epoch for which self.indices was last shuffled.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

    @abstractmethod
    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        """Write data from year index `year_idx` and sample index `idx` to output"""
        pass

    def __call__(self, sample_info: dali.types.SampleInfo) -> Tuple[Tensor, np.ndarray]:
        """Produce one sample: (state_seq, timestamps, *aux_outputs)."""
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        # Shuffle before the next epoch starts
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index
        # TODO: This is very hacky, but it works for now
        idx = self.indices[sample_info.idx_in_epoch]
        year_idx = idx // self.num_samples_per_year
        in_idx = idx % self.num_samples_per_year

        state_seq = self._load_sequence(year_idx, in_idx)

        # Load sequence of timestamps (UTC, as POSIX seconds)
        year = self.start_year + year_idx
        start_time = datetime(year, 1, 1, tzinfo=pytz.utc) + timedelta(
            hours=int(in_idx) * self.dt
        )
        timestamps = np.array(
            [
                (start_time + timedelta(hours=i * self.stride * self.dt)).timestamp()
                for i in range(self.num_steps)
            ]
        )

        # outputs from auxiliary sources
        aux_outputs = (
            callback(timestamps, self.latlon)
            for callback in self.aux_variables.values()
        )

        return (state_seq, timestamps, *aux_outputs)

    def num_outputs(self):
        """Number of outputs per sample: state_seq + timestamps + aux vars."""
        return 2 + len(self.aux_variables)

    def __len__(self):
        """Number of samples in this process's shard."""
        return len(self.indices)
|
| 756 |
+
|
| 757 |
+
|
| 758 |
+
class ClimateHDF5DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading HDF5 formatted climate data files."""

    def _get_data_file(self, year_idx: int) -> h5py.File:
        """Return the opened file for year `year_idx`."""
        handle = self.data_files[year_idx]
        if handle is None:
            # Opened once per worker process. Workers are persistent, so the
            # handle does not need explicit closing — it is released when the
            # corresponding pipeline/dataset is destroyed. Lazy opening also
            # avoids unnecessary file open ops when sharding.
            handle = h5py.File(self.data_paths[year_idx], "r")
            self.data_files[year_idx] = handle
        return handle

    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        # TODO: the data is returned in a weird (time, channels, width, height) shape
        fields = self._get_data_file(year_idx)["fields"]
        time_window = slice(idx, idx + self.num_steps * self.stride, self.stride)
        return fields[time_window, self.chans]
|
| 775 |
+
|
| 776 |
+
|
| 777 |
+
class ClimateNetCDF4DaliExternalSource(ClimateDaliExternalSource):
    """DALI source for reading NetCDF4 formatted climate data files.

    The reader backend is selected with ``backend_kwargs={"reader": ...}``;
    valid choices are ``"netcdf4"`` (default) and ``"scipy"``.
    """

    def _get_data_file(self, year_idx: int) -> netcdf_file:
        """Return the opened file for year `year_idx`.

        Raises
        ------
        ValueError
            If an unknown reader backend is configured in ``backend_kwargs``.
        """
        if self.data_files[year_idx] is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when corresponding pipeline/dataset is destroyed
            # Lazy opening avoids unnecessary file open ops when sharding.
            # NOTE: The SciPy NetCDF reader can be used if the netCDF4 library
            # causes crashes.
            reader = self.backend_kwargs.get("reader", "netcdf4")
            if reader == "scipy":
                self.data_files[year_idx] = netcdf_file(self.data_paths[year_idx])
            elif reader == "netcdf4":
                self.data_files[year_idx] = nc.Dataset(self.data_paths[year_idx], "r")
                # Packing is undone manually in _load_sequence.
                self.data_files[year_idx].set_auto_maskandscale(False)
            else:
                # Previously an unknown reader silently left the handle as
                # None, causing a confusing failure later; fail fast instead.
                raise ValueError(
                    f"Unknown NetCDF reader backend {reader!r}; "
                    "expected 'netcdf4' or 'scipy'"
                )

        return self.data_files[year_idx]

    def _load_sequence(self, year_idx: int, idx: int) -> np.array:
        data_file = self._get_data_file(year_idx)
        shape = data_file.variables[self.variables[0]].shape
        shape = (self.num_steps, len(self.variables)) + shape[1:]
        # TODO: this can be optimized to do the NetCDF scale/offset on GPU
        output = np.empty(shape, dtype=np.float32)
        for i, var in enumerate(self.variables):
            v = data_file.variables[var]
            output[:, i] = v[
                idx : idx + self.num_steps * self.stride : self.stride
            ].copy()  # .copy() avoids hanging references
            # Undo CF packing by hand (auto mask-and-scale is disabled above):
            # unpacked = packed * scale_factor + add_offset
            if hasattr(v, "scale_factor"):
                output[:, i] *= v.scale_factor
            if hasattr(v, "add_offset"):
                output[:, i] += v.add_offset
        return output
|
physics_mcp/source/physicsnemo/datapipes/climate/era5_hdf5.py
ADDED
|
@@ -0,0 +1,622 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import h5py
|
| 18 |
+
import numpy as np
|
| 19 |
+
import torch
|
| 20 |
+
|
| 21 |
+
try:
|
| 22 |
+
import nvidia.dali as dali
|
| 23 |
+
import nvidia.dali.plugin.pytorch as dali_pth
|
| 24 |
+
except ImportError:
|
| 25 |
+
raise ImportError(
|
| 26 |
+
"DALI dataset requires NVIDIA DALI package to be installed. "
|
| 27 |
+
+ "The package can be installed at:\n"
|
| 28 |
+
+ "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
|
| 29 |
+
)
|
| 30 |
+
|
| 31 |
+
from dataclasses import dataclass
|
| 32 |
+
from datetime import datetime, timedelta
|
| 33 |
+
from pathlib import Path
|
| 34 |
+
from typing import Dict, Iterable, List, Tuple, Union
|
| 35 |
+
|
| 36 |
+
import pytz
|
| 37 |
+
|
| 38 |
+
from physicsnemo.datapipes.climate.utils.invariant import latlon_grid
|
| 39 |
+
from physicsnemo.datapipes.climate.utils.zenith_angle import cos_zenith_angle
|
| 40 |
+
|
| 41 |
+
from ..datapipe import Datapipe
|
| 42 |
+
from ..meta import DatapipeMetaData
|
| 43 |
+
|
| 44 |
+
Tensor = torch.Tensor
|
| 45 |
+
|
| 46 |
+
|
| 47 |
+
@dataclass
class MetaData(DatapipeMetaData):
    """Datapipe metadata flags for the ERA5 HDF5 pipeline."""

    name: str = "ERA5HDF5"
    # Optimization
    auto_device: bool = True
    cuda_graphs: bool = True
    # Parallel
    ddp_sharding: bool = True
|
| 55 |
+
|
| 56 |
+
|
| 57 |
+
class ERA5HDF5Datapipe(Datapipe):
|
| 58 |
+
"""ERA5 DALI data pipeline for HDF5 files
|
| 59 |
+
|
| 60 |
+
Parameters
|
| 61 |
+
----------
|
| 62 |
+
data_dir : str
|
| 63 |
+
Directory where ERA5 data is stored
|
| 64 |
+
stats_dir : Union[str, None], optional
|
| 65 |
+
Directory to data statistic numpy files for normalization, if None, no normalization
|
| 66 |
+
will be used, by default None
|
| 67 |
+
channels : Union[List[int], None], optional
|
| 68 |
+
Defines which ERA5 variables to load, if None will use all in HDF5 file, by default None
|
| 69 |
+
batch_size : int, optional
|
| 70 |
+
Batch size, by default 1
|
| 71 |
+
stride : int, optional
|
| 72 |
+
Number of steps between input and output variables. For example, if the dataset
|
| 73 |
+
contains data at every 6 hours, a stride 1 = 6 hour delta t and
|
| 74 |
+
stride 2 = 12 hours delta t, by default 1
|
| 75 |
+
num_steps : int, optional
|
| 76 |
+
Number of timesteps are included in the output variables, by default 1
|
| 77 |
+
num_history : int, optional
|
| 78 |
+
Number of previous timesteps included in the input variables, by default 0
|
| 79 |
+
latlon_resolution: Tuple[int, int], optional
|
| 80 |
+
The resolution for the latitude-longitude grid (H, W). Needs to be specified
|
| 81 |
+
for cos zenith angle computation, or interpolation. By default None
|
| 82 |
+
interpolation_type: str, optional
|
| 83 |
+
Interpolation type for resizing. Supports ["INTERP_NN", "INTERP_LINEAR", "INTERP_CUBIC",
|
| 84 |
+
"INTERP_LANCZOS3", "INTERP_TRIANGULAR", "INTERP_GAUSSIAN"]. By default None
|
| 85 |
+
(no interpolation is done)
|
| 86 |
+
patch_size : Union[Tuple[int, int], int, None], optional
|
| 87 |
+
If specified, crops input and output variables so image dimensions are
|
| 88 |
+
divisible by patch_size, by default None
|
| 89 |
+
num_samples_per_year : int, optional
|
| 90 |
+
Number of samples randomly taken from each year. If None, all will be used, by default None
|
| 91 |
+
use_cos_zenith: bool, optional
|
| 92 |
+
If True, the cosine zenith angles corresponding to the coordinates will be produced,
|
| 93 |
+
by default False
|
| 94 |
+
cos_zenith_args: Dict, optional
|
| 95 |
+
Dictionary containing the following:
|
| 96 |
+
|
| 97 |
+
dt: float, optional
|
| 98 |
+
Time in hours between each timestep in the dataset, by default 6 hr
|
| 99 |
+
|
| 100 |
+
start_year: int, optional
|
| 101 |
+
Start year of dataset, by default 1980
|
| 102 |
+
|
| 103 |
+
latlon_bounds : Tuple[Tuple[float, float], Tuple[float, float]], optional
|
| 104 |
+
Bounds of latitude and longitude in the data, in the format
|
| 105 |
+
((lat_start, lat_end,), (lon_start, lon_end)).
|
| 106 |
+
By default ((90, -90), (0, 360)).
|
| 107 |
+
|
| 108 |
+
Defaults are only applicable if use_cos_zenith is True. Otherwise, defaults to {}.
|
| 109 |
+
use_time_of_year_index: bool
|
| 110 |
+
If true, also returns the index that can be used to determine the time of the year
|
| 111 |
+
corresponding to each sample. By default False.
|
| 112 |
+
shuffle : bool, optional
|
| 113 |
+
Shuffle dataset, by default True
|
| 114 |
+
num_workers : int, optional
|
| 115 |
+
Number of workers, by default 1
|
| 116 |
+
device: Union[str, torch.device], optional
|
| 117 |
+
Device for DALI pipeline to run on, by default cuda
|
| 118 |
+
process_rank : int, optional
|
| 119 |
+
Rank ID of local process, by default 0
|
| 120 |
+
world_size : int, optional
|
| 121 |
+
Number of training processes, by default 1
|
| 122 |
+
"""
|
| 123 |
+
|
| 124 |
+
def __init__(
|
| 125 |
+
self,
|
| 126 |
+
data_dir: str,
|
| 127 |
+
stats_dir: Union[str, None] = None,
|
| 128 |
+
channels: Union[List[int], None] = None,
|
| 129 |
+
batch_size: int = 1,
|
| 130 |
+
num_steps: int = 1,
|
| 131 |
+
num_history: int = 0,
|
| 132 |
+
stride: int = 1,
|
| 133 |
+
latlon_resolution: Union[Tuple[int, int], None] = None,
|
| 134 |
+
interpolation_type: Union[str, None] = None,
|
| 135 |
+
patch_size: Union[Tuple[int, int], int, None] = None,
|
| 136 |
+
num_samples_per_year: Union[int, None] = None,
|
| 137 |
+
use_cos_zenith: bool = False,
|
| 138 |
+
cos_zenith_args: Dict = {},
|
| 139 |
+
use_time_of_year_index: bool = False,
|
| 140 |
+
shuffle: bool = True,
|
| 141 |
+
num_workers: int = 1,
|
| 142 |
+
device: Union[str, torch.device] = "cuda",
|
| 143 |
+
process_rank: int = 0,
|
| 144 |
+
world_size: int = 1,
|
| 145 |
+
):
|
| 146 |
+
super().__init__(meta=MetaData())
|
| 147 |
+
self.batch_size = batch_size
|
| 148 |
+
self.num_workers = num_workers
|
| 149 |
+
self.shuffle = shuffle
|
| 150 |
+
self.data_dir = Path(data_dir)
|
| 151 |
+
self.stats_dir = Path(stats_dir) if stats_dir is not None else None
|
| 152 |
+
self.channels = channels
|
| 153 |
+
self.stride = stride
|
| 154 |
+
self.latlon_resolution = latlon_resolution
|
| 155 |
+
self.interpolation_type = interpolation_type
|
| 156 |
+
self.num_steps = num_steps
|
| 157 |
+
self.num_history = num_history
|
| 158 |
+
self.num_samples_per_year = num_samples_per_year
|
| 159 |
+
self.use_cos_zenith = use_cos_zenith
|
| 160 |
+
self.cos_zenith_args = cos_zenith_args
|
| 161 |
+
self.use_time_of_year_index = use_time_of_year_index
|
| 162 |
+
self.process_rank = process_rank
|
| 163 |
+
self.world_size = world_size
|
| 164 |
+
|
| 165 |
+
# cos zenith defaults
|
| 166 |
+
if use_cos_zenith:
|
| 167 |
+
cos_zenith_args["dt"] = cos_zenith_args.get("dt", 6.0)
|
| 168 |
+
cos_zenith_args["start_year"] = cos_zenith_args.get("start_year", 1980)
|
| 169 |
+
cos_zenith_args["latlon_bounds"] = cos_zenith_args.get(
|
| 170 |
+
"latlon_bounds",
|
| 171 |
+
(
|
| 172 |
+
(90, -90),
|
| 173 |
+
(0, 360),
|
| 174 |
+
),
|
| 175 |
+
)
|
| 176 |
+
self.latlon_bounds = cos_zenith_args.get("latlon_bounds")
|
| 177 |
+
|
| 178 |
+
if isinstance(patch_size, int):
|
| 179 |
+
patch_size = (patch_size, patch_size)
|
| 180 |
+
self.patch_size = patch_size
|
| 181 |
+
|
| 182 |
+
# Set up device, needed for pipeline
|
| 183 |
+
if isinstance(device, str):
|
| 184 |
+
device = torch.device(device)
|
| 185 |
+
# Need a index id if cuda
|
| 186 |
+
if device.type == "cuda" and device.index is None:
|
| 187 |
+
device = torch.device("cuda:0")
|
| 188 |
+
self.device = device
|
| 189 |
+
|
| 190 |
+
# check root directory exists
|
| 191 |
+
if not self.data_dir.is_dir():
|
| 192 |
+
raise IOError(f"Error, data directory {self.data_dir} does not exist")
|
| 193 |
+
if self.stats_dir is not None and not self.stats_dir.is_dir():
|
| 194 |
+
raise IOError(f"Error, stats directory {self.stats_dir} does not exist")
|
| 195 |
+
|
| 196 |
+
# Check interpolation type
|
| 197 |
+
if self.interpolation_type is not None:
|
| 198 |
+
valid_interpolation = [
|
| 199 |
+
"INTERP_NN",
|
| 200 |
+
"INTERP_LINEAR",
|
| 201 |
+
"INTERP_CUBIC",
|
| 202 |
+
"INTERP_LANCZOS3",
|
| 203 |
+
"INTERP_TRIANGULAR",
|
| 204 |
+
"INTERP_GAUSSIAN",
|
| 205 |
+
]
|
| 206 |
+
if self.interpolation_type not in valid_interpolation:
|
| 207 |
+
raise ValueError(
|
| 208 |
+
f"Interpolation type {self.interpolation_type} not supported"
|
| 209 |
+
)
|
| 210 |
+
self.interpolation_type = getattr(dali.types, self.interpolation_type)
|
| 211 |
+
|
| 212 |
+
# Layout
|
| 213 |
+
# Avoiding API change for self.num_history == 0.
|
| 214 |
+
# Need to use FCHW layout in the future regardless of the num_history.
|
| 215 |
+
if self.num_history == 0:
|
| 216 |
+
self.layout = ["CHW", "FCHW"]
|
| 217 |
+
else:
|
| 218 |
+
self.layout = ["FCHW", "FCHW"]
|
| 219 |
+
|
| 220 |
+
self.output_keys = ["invar", "outvar"]
|
| 221 |
+
|
| 222 |
+
# Get latlon for zenith angle
|
| 223 |
+
if self.use_cos_zenith:
|
| 224 |
+
if not self.latlon_resolution:
|
| 225 |
+
raise ValueError("latlon_resolution must be set for cos zenith angle")
|
| 226 |
+
self.data_latlon = np.stack(
|
| 227 |
+
latlon_grid(bounds=self.latlon_bounds, shape=self.latlon_resolution),
|
| 228 |
+
axis=0,
|
| 229 |
+
)
|
| 230 |
+
self.latlon_dali = dali.types.Constant(self.data_latlon)
|
| 231 |
+
self.output_keys += ["cos_zenith"]
|
| 232 |
+
|
| 233 |
+
if self.use_time_of_year_index:
|
| 234 |
+
self.output_keys += ["time_of_year_idx"]
|
| 235 |
+
|
| 236 |
+
self.parse_dataset_files()
|
| 237 |
+
self.load_statistics()
|
| 238 |
+
|
| 239 |
+
self.pipe = self._create_pipeline()
|
| 240 |
+
|
| 241 |
+
def parse_dataset_files(self) -> None:
|
| 242 |
+
"""Parses the data directory for valid HDF5 files and determines training samples
|
| 243 |
+
|
| 244 |
+
Raises
|
| 245 |
+
------
|
| 246 |
+
ValueError
|
| 247 |
+
In channels specified or number of samples per year is not valid
|
| 248 |
+
"""
|
| 249 |
+
# get all input data files
|
| 250 |
+
self.data_paths = sorted(self.data_dir.glob("????.h5"))
|
| 251 |
+
for data_path in self.data_paths:
|
| 252 |
+
self.logger.info(f"ERA5 file found: {data_path}")
|
| 253 |
+
self.n_years = len(self.data_paths)
|
| 254 |
+
self.logger.info(f"Number of years: {self.n_years}")
|
| 255 |
+
|
| 256 |
+
# get total number of examples and image shape from the first file,
|
| 257 |
+
# assuming other files have exactly the same format.
|
| 258 |
+
self.logger.info(f"Getting file stats from {self.data_paths[0]}")
|
| 259 |
+
with h5py.File(self.data_paths[0], "r") as f:
|
| 260 |
+
# truncate the dataset to avoid out-of-range sampling and ensure each
|
| 261 |
+
# rank has same number of samples (to avoid deadlocks)
|
| 262 |
+
data_samples_per_year = (
|
| 263 |
+
(
|
| 264 |
+
f["fields"].shape[0]
|
| 265 |
+
- (self.num_steps + self.num_history) * self.stride
|
| 266 |
+
)
|
| 267 |
+
// self.world_size
|
| 268 |
+
) * self.world_size
|
| 269 |
+
if data_samples_per_year < 1:
|
| 270 |
+
raise ValueError(
|
| 271 |
+
f"Not enough number of samples per year ({data_samples_per_year})"
|
| 272 |
+
)
|
| 273 |
+
self.img_shape = f["fields"].shape[2:]
|
| 274 |
+
|
| 275 |
+
# If channels not provided, use all of them
|
| 276 |
+
if self.channels is None:
|
| 277 |
+
self.channels = [i for i in range(f["fields"].shape[1])]
|
| 278 |
+
|
| 279 |
+
# If num_samples_per_year use all
|
| 280 |
+
if self.num_samples_per_year is None:
|
| 281 |
+
self.num_samples_per_year = data_samples_per_year
|
| 282 |
+
|
| 283 |
+
# Adjust image shape if patch_size defined
|
| 284 |
+
if self.patch_size is not None:
|
| 285 |
+
if self.use_cos_zenith:
|
| 286 |
+
raise ValueError("Patching is not supported with cos zenith angle")
|
| 287 |
+
self.img_shape = [
|
| 288 |
+
s - s % self.patch_size[i] for i, s in enumerate(self.img_shape)
|
| 289 |
+
]
|
| 290 |
+
self.logger.info(f"Input image shape: {self.img_shape}")
|
| 291 |
+
|
| 292 |
+
# Get total length
|
| 293 |
+
self.total_length = self.n_years * self.num_samples_per_year
|
| 294 |
+
self.length = self.total_length
|
| 295 |
+
|
| 296 |
+
# Sanity checks
|
| 297 |
+
if max(self.channels) >= f["fields"].shape[1]:
|
| 298 |
+
raise ValueError(
|
| 299 |
+
f"Provided channel has indexes greater than the number \
|
| 300 |
+
of fields {f['fields'].shape[1]}"
|
| 301 |
+
)
|
| 302 |
+
|
| 303 |
+
if self.num_samples_per_year > data_samples_per_year:
|
| 304 |
+
raise ValueError(
|
| 305 |
+
f"num_samples_per_year ({self.num_samples_per_year}) > number of \
|
| 306 |
+
samples available ({data_samples_per_year})!"
|
| 307 |
+
)
|
| 308 |
+
|
| 309 |
+
self.logger.info(f"Number of samples/year: {self.num_samples_per_year}")
|
| 310 |
+
self.logger.info(f"Number of channels available: {f['fields'].shape[1]}")
|
| 311 |
+
|
| 312 |
+
def load_statistics(self) -> None:
|
| 313 |
+
"""Loads ERA5 statistics from pre-computed numpy files
|
| 314 |
+
|
| 315 |
+
The statistic files should be of name global_means.npy and global_std.npy with
|
| 316 |
+
a shape of [1, C, 1, 1] located in the stat_dir.
|
| 317 |
+
|
| 318 |
+
Raises
|
| 319 |
+
------
|
| 320 |
+
IOError
|
| 321 |
+
If mean or std numpy files are not found
|
| 322 |
+
AssertionError
|
| 323 |
+
If loaded numpy arrays are not of correct size
|
| 324 |
+
"""
|
| 325 |
+
# If no stats dir we just skip loading the stats
|
| 326 |
+
if self.stats_dir is None:
|
| 327 |
+
self.mu = None
|
| 328 |
+
self.std = None
|
| 329 |
+
return
|
| 330 |
+
# load normalisation values
|
| 331 |
+
mean_stat_file = self.stats_dir / Path("global_means.npy")
|
| 332 |
+
std_stat_file = self.stats_dir / Path("global_stds.npy")
|
| 333 |
+
|
| 334 |
+
if not mean_stat_file.exists():
|
| 335 |
+
raise IOError(f"Mean statistics file {mean_stat_file} not found")
|
| 336 |
+
if not std_stat_file.exists():
|
| 337 |
+
raise IOError(f"Std statistics file {std_stat_file} not found")
|
| 338 |
+
|
| 339 |
+
# has shape [1, C, 1, 1]
|
| 340 |
+
self.mu = np.load(str(mean_stat_file))[:, self.channels]
|
| 341 |
+
# has shape [1, C, 1, 1]
|
| 342 |
+
self.sd = np.load(str(std_stat_file))[:, self.channels]
|
| 343 |
+
|
| 344 |
+
if not self.mu.shape == self.sd.shape == (1, len(self.channels), 1, 1):
|
| 345 |
+
raise AssertionError("Error, normalisation arrays have wrong shape")
|
| 346 |
+
|
| 347 |
+
    def _create_pipeline(self) -> dali.Pipeline:
        """Create DALI pipeline

        Builds the graph: external HDF5 source -> optional GPU transfer ->
        crop -> optional normalization -> optional resize, plus the optional
        cos-zenith-angle and time-of-year-index outputs.

        Returns
        -------
        dali.Pipeline
            HDF5 DALI pipeline
        """
        pipe = dali.Pipeline(
            batch_size=self.batch_size,
            num_threads=2,
            prefetch_queue_depth=2,
            py_num_workers=self.num_workers,
            device_id=self.device.index,
            py_start_method="spawn",
        )

        with pipe:
            source = ERA5DaliExternalSource(
                data_paths=self.data_paths,
                num_samples=self.total_length,
                channels=self.channels,
                stride=self.stride,
                num_steps=self.num_steps,
                num_history=self.num_history,
                num_samples_per_year=self.num_samples_per_year,
                use_cos_zenith=self.use_cos_zenith,
                cos_zenith_args=self.cos_zenith_args,
                use_time_of_year_index=self.use_time_of_year_index,
                batch_size=self.batch_size,
                shuffle=self.shuffle,
                process_rank=self.process_rank,
                world_size=self.world_size,
            )
            # Update length of dataset: batches per epoch for this rank's shard.
            self.length = len(source) // self.batch_size
            # Read current batch.
            invar, outvar, timestamps, time_of_year_idx = dali.fn.external_source(
                source,
                num_outputs=4,
                parallel=True,
                batch=False,
                layout=self.layout,
            )
            if self.device.type == "cuda":
                # Move tensors to GPU as external_source won't do that.
                invar = invar.gpu()
                outvar = outvar.gpu()

            # Crop to the (possibly patch-aligned) image shape.
            h, w = self.img_shape
            if self.num_history == 0:
                # invar has [C,H,W] layout when no history is kept.
                invar = invar[:, :h, :w]
            else:
                # invar has [T,C,H,W] layout when history steps are included.
                invar = invar[:, :, :h, :w]
            outvar = outvar[:, :, :h, :w]

            # Standardize with the precomputed per-channel statistics.
            if self.stats_dir is not None:
                if self.num_history == 0:
                    invar = dali.fn.normalize(invar, mean=self.mu[0], stddev=self.sd[0])
                else:
                    invar = dali.fn.normalize(invar, mean=self.mu, stddev=self.sd)
                outvar = dali.fn.normalize(outvar, mean=self.mu, stddev=self.sd)

            # Resize.
            # NOTE(review): this reads self.latlon_resolution, which is only
            # validated (non-None) when use_cos_zenith is set — confirm that
            # interpolation_type implies latlon_resolution is provided.
            if self.interpolation_type is not None:
                invar = dali.fn.resize(
                    invar,
                    resize_x=self.latlon_resolution[1],
                    resize_y=self.latlon_resolution[0],
                    interp_type=self.interpolation_type,
                    antialias=False,
                )
                outvar = dali.fn.resize(
                    outvar,
                    resize_x=self.latlon_resolution[1],
                    resize_y=self.latlon_resolution[0],
                    interp_type=self.interpolation_type,
                    antialias=False,
                )

            # cos zenith angle computed from the per-sample timestamps.
            if self.use_cos_zenith:
                cos_zenith = dali.fn.cast(
                    cos_zenith_angle(timestamps, latlon=self.latlon_dali),
                    dtype=dali.types.FLOAT,
                )
                if self.device.type == "cuda":
                    cos_zenith = cos_zenith.gpu()

            # # Time of the year
            # time_of_year_idx = dali.fn.cast(
            #     time_of_year_idx,
            #     dtype=dali.types.UINT32,
            # )

            # Set outputs. Order must match self.output_keys.
            outputs = (invar, outvar)
            if self.use_cos_zenith:
                outputs += (cos_zenith,)
            if self.use_time_of_year_index:
                outputs += (time_of_year_idx,)
            pipe.set_outputs(*outputs)

        return pipe
|
| 454 |
+
def __iter__(self):
|
| 455 |
+
# Reset the pipeline before creating an iterator to enable epochs.
|
| 456 |
+
self.pipe.reset()
|
| 457 |
+
# Create DALI PyTorch iterator.
|
| 458 |
+
return dali_pth.DALIGenericIterator([self.pipe], self.output_keys)
|
| 459 |
+
|
| 460 |
+
def __len__(self):
|
| 461 |
+
return self.length
|
| 462 |
+
|
| 463 |
+
|
| 464 |
+
class ERA5DaliExternalSource:
    """DALI Source for lazy-loading the HDF5 ERA5 files

    Parameters
    ----------
    data_paths : Iterable[str]
        Directory where ERA5 data is stored
    num_samples : int
        Total number of training samples
    channels : Iterable[int]
        List representing which ERA5 variables to load
    num_steps : int
        Number of timesteps are included in the output variables
    num_history : int
        Number of previous timesteps included in the input variables
    stride : int
        Number of steps between input and output variables
    num_samples_per_year : int
        Number of samples randomly taken from each year
    use_cos_zenith: bool
        If True, the cosine zenith angles corresponding to the coordinates will be produced
    cos_zenith_args: Dict
        Dictionary containing the following:

        dt: float
            Time in hours between each timestep in the dataset

        start_year: int
            Start year of dataset
    use_time_of_year_index: bool
        If True, the sample's index within its year is emitted (otherwise -1)
    batch_size : int, optional
        Batch size, by default 1
    shuffle : bool, optional
        Shuffle dataset, by default True
    process_rank : int, optional
        Rank ID of local process, by default 0
    world_size : int, optional
        Number of training processes, by default 1

    Note
    ----
    For more information about DALI external source operator:
    https://docs.nvidia.com/deeplearning/dali/archives/dali_1_13_0/user-guide/docs/examples/general/data_loading/parallel_external_source.html
    """

    def __init__(
        self,
        data_paths: Iterable[str],
        num_samples: int,
        channels: Iterable[int],
        num_steps: int,
        num_history: int,
        stride: int,
        num_samples_per_year: int,
        use_cos_zenith: bool,
        cos_zenith_args: Dict,
        use_time_of_year_index: bool,
        batch_size: int = 1,
        shuffle: bool = True,
        process_rank: int = 0,
        world_size: int = 1,
    ):
        self.data_paths = list(data_paths)
        # Will be populated later once each worker starts running in its own process.
        self.data_files = None
        self.num_samples = num_samples
        self.chans = list(channels)
        self.num_steps = num_steps
        self.num_history = num_history
        self.stride = stride
        self.num_samples_per_year = num_samples_per_year
        self.use_cos_zenith = use_cos_zenith
        self.use_time_of_year_index = use_time_of_year_index
        self.batch_size = batch_size
        self.shuffle = shuffle

        # Tracks the last epoch for which the indices were shuffled.
        self.last_epoch = None

        self.indices = np.arange(num_samples)
        # Shard from indices if running in parallel
        self.indices = np.array_split(self.indices, world_size)[process_rank]

        # Get number of full batches, ignore possible last incomplete batch for now.
        # Also, DALI external source does not support incomplete batches in parallel mode.
        self.num_batches = len(self.indices) // self.batch_size

        # cos zenith args
        if self.use_cos_zenith:
            self.dt: float = cos_zenith_args.get("dt")
            self.start_year: int = cos_zenith_args.get("start_year")

    def __call__(
        self, sample_info: dali.types.SampleInfo
    ) -> Tuple[Tensor, Tensor, np.ndarray]:
        # Stop once every full batch of this shard has been produced.
        if sample_info.iteration >= self.num_batches:
            raise StopIteration()

        if self.data_files is None:
            # This will be called once per worker. Workers are persistent,
            # so there is no need to explicitly close the files - this will be done
            # when corresponding pipeline/dataset is destroyed.
            self.data_files = [h5py.File(path, "r") for path in self.data_paths]

        # Shuffle before the next epoch starts.
        if self.shuffle and sample_info.epoch_idx != self.last_epoch:
            # All workers use the same rng seed so the resulting
            # indices are the same across workers.
            np.random.default_rng(seed=sample_info.epoch_idx).shuffle(self.indices)
            self.last_epoch = sample_info.epoch_idx

        # Get local indices from global index: which year file, and which
        # sample within that year.
        idx = self.indices[sample_info.idx_in_epoch]
        year_idx = idx // self.num_samples_per_year
        in_idx = idx % self.num_samples_per_year

        # Load sequence of timestamps
        if self.use_cos_zenith:
            year = self.start_year + year_idx
            # NOTE(review): assumes samples are spaced self.dt hours apart
            # starting Jan 1 (UTC) of each year — confirm against the dataset.
            start_time = datetime(year, 1, 1, tzinfo=pytz.utc) + timedelta(
                hours=int(in_idx) * self.dt
            )
            timestamps = np.array(
                [
                    (
                        start_time + timedelta(hours=i * self.stride * self.dt)
                    ).timestamp()
                    for i in range(self.num_history + self.num_steps + 1)
                ]
            )
        else:
            timestamps = np.array([])
        if self.use_time_of_year_index:
            time_of_year_idx = in_idx
        else:
            # Sentinel value when the time-of-year index is not requested.
            time_of_year_idx = -1

        data = self.data_files[year_idx]["fields"]
        if self.num_history == 0:
            # Has [C,H,W] shape.
            invar = data[in_idx, self.chans]
        else:
            # Has [T,C,H,W] shape.
            invar = data[
                in_idx : in_idx + (self.num_history + 1) * self.stride : self.stride,
                self.chans,
            ]

        # Has [T,C,H,W] shape.
        outvar = np.empty((self.num_steps,) + invar.shape[-3:], dtype=invar.dtype)

        # Output steps start after the input history window, `stride` apart.
        for i in range(self.num_steps):
            out_idx = in_idx + (self.num_history + i + 1) * self.stride
            outvar[i] = data[out_idx, self.chans]

        return invar, outvar, timestamps, np.array([time_of_year_idx])

    def __len__(self):
        # Length of this rank's shard, in samples (not batches).
        return len(self.indices)
|
physics_mcp/source/physicsnemo/datapipes/climate/era5_netcdf.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
physics_mcp/source/physicsnemo/datapipes/climate/synthetic.py
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
import time
|
| 19 |
+
from typing import Any, Dict, List, Tuple
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
from torch.utils.data import DataLoader, Dataset
|
| 24 |
+
|
| 25 |
+
|
| 26 |
+
class SyntheticWeatherDataLoader(DataLoader):
    """
    This custom DataLoader initializes the SyntheticWeatherDataset with given arguments.
    """

    def __init__(self, *args, **kwargs):
        synthetic_dataset = SyntheticWeatherDataset(*args, **kwargs)
        # Forward the usual loader options, falling back to DataLoader-style
        # defaults when they are not supplied.
        loader_options = {
            "batch_size": kwargs.get("batch_size", 1),
            "shuffle": kwargs.get("shuffle", False),
            "num_workers": kwargs.get("num_workers", 0),
            "pin_memory": kwargs.get("pin_memory", False),
            "drop_last": kwargs.get("drop_last", False),
        }
        super().__init__(dataset=synthetic_dataset, **loader_options)
|
| 42 |
+
|
| 43 |
+
class SyntheticWeatherDataset(Dataset):
|
| 44 |
+
"""
|
| 45 |
+
A dataset for generating synthetic temperature data on a latitude-longitude grid for multiple atmospheric layers.
|
| 46 |
+
|
| 47 |
+
Args:
|
| 48 |
+
channels (list): List of channels representing different atmospheric layers.
|
| 49 |
+
num_samples_per_year (int): Total number of days to simulate per year.
|
| 50 |
+
num_steps (int): Number of consecutive days in each training sample.
|
| 51 |
+
grid_size (tuple): Latitude by longitude dimensions of the temperature grid.
|
| 52 |
+
base_temp (float): Base temperature around which variations are simulated.
|
| 53 |
+
amplitude (float): Amplitude of the sinusoidal temperature variation.
|
| 54 |
+
noise_level (float): Standard deviation of the noise added to temperature data.
|
| 55 |
+
**kwargs: Additional keyword arguments for advanced configurations.
|
| 56 |
+
"""
|
| 57 |
+
|
| 58 |
+
def __init__(
|
| 59 |
+
self,
|
| 60 |
+
channels: List[int],
|
| 61 |
+
num_samples_per_year: int,
|
| 62 |
+
num_steps: int,
|
| 63 |
+
device: str | torch.device = "cuda",
|
| 64 |
+
grid_size: Tuple[int, int] = (721, 1440),
|
| 65 |
+
base_temp: float = 15,
|
| 66 |
+
amplitude: float = 10,
|
| 67 |
+
noise_level: float = 2,
|
| 68 |
+
**kwargs: Any,
|
| 69 |
+
):
|
| 70 |
+
self.num_days: int = num_samples_per_year
|
| 71 |
+
self.num_steps: int = num_steps
|
| 72 |
+
self.num_channels: int = len(channels)
|
| 73 |
+
self.device = device
|
| 74 |
+
self.grid_size: Tuple[int, int] = grid_size
|
| 75 |
+
start_time = time.time()
|
| 76 |
+
self.temperatures: np.ndarray = self.generate_data(
|
| 77 |
+
self.num_days,
|
| 78 |
+
self.num_channels,
|
| 79 |
+
self.grid_size,
|
| 80 |
+
base_temp,
|
| 81 |
+
amplitude,
|
| 82 |
+
noise_level,
|
| 83 |
+
)
|
| 84 |
+
print(
|
| 85 |
+
f"Generated synthetic temperature data in {time.time() - start_time:.2f} seconds."
|
| 86 |
+
)
|
| 87 |
+
self.extra_args: Dict[str, Any] = kwargs
|
| 88 |
+
|
| 89 |
+
def generate_data(
|
| 90 |
+
self,
|
| 91 |
+
num_days: int,
|
| 92 |
+
num_channels: int,
|
| 93 |
+
grid_size: Tuple[int, int],
|
| 94 |
+
base_temp: float,
|
| 95 |
+
amplitude: float,
|
| 96 |
+
noise_level: float,
|
| 97 |
+
) -> np.ndarray:
|
| 98 |
+
"""
|
| 99 |
+
Generates synthetic temperature data over a specified number of days for multiple atmospheric layers.
|
| 100 |
+
|
| 101 |
+
Args:
|
| 102 |
+
num_days (int): Number of days to generate data for.
|
| 103 |
+
num_channels (int): Number of channels representing different layers.
|
| 104 |
+
grid_size (tuple): Grid size (latitude, longitude).
|
| 105 |
+
base_temp (float): Base mean temperature for the data.
|
| 106 |
+
amplitude (float): Amplitude of temperature variations.
|
| 107 |
+
noise_level (float): Noise level to add stochasticity to the temperature.
|
| 108 |
+
|
| 109 |
+
Returns:
|
| 110 |
+
numpy.ndarray: A 4D array of temperature values across days, channels, latitudes, and longitudes.
|
| 111 |
+
"""
|
| 112 |
+
days = np.arange(num_days)
|
| 113 |
+
latitudes, longitudes = grid_size
|
| 114 |
+
|
| 115 |
+
# Create altitude effect and reshape
|
| 116 |
+
altitude_effect = np.arange(num_channels) * -0.5
|
| 117 |
+
altitude_effect = altitude_effect[
|
| 118 |
+
:, np.newaxis, np.newaxis
|
| 119 |
+
] # Shape: (num_channels, 1, 1)
|
| 120 |
+
altitude_effect = np.tile(
|
| 121 |
+
altitude_effect, (1, latitudes, longitudes)
|
| 122 |
+
) # Shape: (num_channels, latitudes, longitudes)
|
| 123 |
+
altitude_effect = altitude_effect[
|
| 124 |
+
np.newaxis, :, :, :
|
| 125 |
+
] # Shape: (1, num_channels, latitudes, longitudes)
|
| 126 |
+
altitude_effect = np.tile(
|
| 127 |
+
altitude_effect, (num_days, 1, 1, 1)
|
| 128 |
+
) # Shape: (num_days, num_channels, latitudes, longitudes)
|
| 129 |
+
|
| 130 |
+
# Create latitude variation and reshape
|
| 131 |
+
lat_variation = np.linspace(-amplitude, amplitude, latitudes)
|
| 132 |
+
lat_variation = lat_variation[:, np.newaxis] # Shape: (latitudes, 1)
|
| 133 |
+
lat_variation = np.tile(
|
| 134 |
+
lat_variation, (1, longitudes)
|
| 135 |
+
) # Shape: (latitudes, longitudes)
|
| 136 |
+
lat_variation = lat_variation[
|
| 137 |
+
np.newaxis, np.newaxis, :, :
|
| 138 |
+
] # Shape: (1, 1, latitudes, longitudes)
|
| 139 |
+
lat_variation = np.tile(
|
| 140 |
+
lat_variation, (num_days, num_channels, 1, 1)
|
| 141 |
+
) # Shape: (num_days, num_channels, latitudes, longitudes)
|
| 142 |
+
|
| 143 |
+
# Create time effect and reshape
|
| 144 |
+
time_effect = np.sin(2 * np.pi * days / 365)
|
| 145 |
+
time_effect = time_effect[
|
| 146 |
+
:, np.newaxis, np.newaxis, np.newaxis
|
| 147 |
+
] # Shape: (num_days, 1, 1, 1)
|
| 148 |
+
time_effect = np.tile(
|
| 149 |
+
time_effect, (1, num_channels, latitudes, longitudes)
|
| 150 |
+
) # Shape: (num_days, num_channels, latitudes, longitudes)
|
| 151 |
+
|
| 152 |
+
# Generate noise
|
| 153 |
+
noise = np.random.normal(
|
| 154 |
+
scale=noise_level, size=(num_days, num_channels, latitudes, longitudes)
|
| 155 |
+
)
|
| 156 |
+
|
| 157 |
+
# Calculate daily temperatures
|
| 158 |
+
daily_temps = base_temp + altitude_effect + lat_variation + time_effect + noise
|
| 159 |
+
|
| 160 |
+
return daily_temps
|
| 161 |
+
|
| 162 |
+
def __len__(self) -> int:
|
| 163 |
+
"""
|
| 164 |
+
Returns the number of samples available in the dataset.
|
| 165 |
+
"""
|
| 166 |
+
return self.num_days - self.num_steps
|
| 167 |
+
|
| 168 |
+
def __getitem__(self, idx: int) -> torch.Tensor:
|
| 169 |
+
"""
|
| 170 |
+
Retrieves a sample from the dataset at the specified index.
|
| 171 |
+
"""
|
| 172 |
+
return [
|
| 173 |
+
{
|
| 174 |
+
"invar": torch.tensor(self.temperatures[idx], dtype=torch.float32).to(
|
| 175 |
+
self.device
|
| 176 |
+
),
|
| 177 |
+
"outvar": torch.tensor(
|
| 178 |
+
self.temperatures[idx + 1 : idx + self.num_steps + 1],
|
| 179 |
+
dtype=torch.float32,
|
| 180 |
+
).to(self.device),
|
| 181 |
+
}
|
| 182 |
+
]
|
physics_mcp/source/physicsnemo/datapipes/climate/utils/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
physics_mcp/source/physicsnemo/datapipes/climate/utils/invariant.py
ADDED
|
@@ -0,0 +1,139 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from typing import List, Tuple
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import xarray as xr
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def latlon_grid(
    bounds: Tuple[Tuple[float, float], Tuple[float, float]] = (
        (90, -90),
        (0, 360),
    ),
    shape: Tuple[int, int] = (1440, 721),
) -> np.ndarray:
    """Infer latitude and longitude coordinates from bounds and data shape on a
    equirectangular grid."""
    (lat_bounds, lon_bounds) = bounds

    # Latitude coordinates straight from the bounds.
    lat = np.linspace(lat_bounds[0], lat_bounds[1], shape[0], dtype=np.float32)

    # Longitude needs special handling when the bounds wrap around the globe:
    # the endpoint would duplicate the start point, so sample one extra step
    # and drop it.
    wraps_around = (lon_bounds[0] % 360) == (lon_bounds[1] % 360)
    if wraps_around:
        lon = np.linspace(
            lon_bounds[0], lon_bounds[1], shape[1] + 1, dtype=np.float32
        )[:-1]
    else:
        lon = np.linspace(lon_bounds[0], lon_bounds[1], shape[1], dtype=np.float32)

    return np.meshgrid(lat, lon, indexing="ij")
| 47 |
+
|
| 48 |
+
class Invariant(ABC):
    """Abstract base for data that is invariant to the inputs on load.

    Concrete subclasses implement ``__call__`` to compute a field from a
    latitude/longitude grid.
    """

    @abstractmethod
    def __call__(self, latlon: np.ndarray):
        pass
+
|
| 56 |
+
class LatLon(Invariant):
|
| 57 |
+
"""Time invariant latitude and longitude coordinates and trig functions"""
|
| 58 |
+
|
| 59 |
+
def __init__(
|
| 60 |
+
self, outputs: List[str] = ("sin_lat", "cos_lat", "sin_lon", "cos_lon")
|
| 61 |
+
):
|
| 62 |
+
"""
|
| 63 |
+
Outputs latitude and longitude and their trigonometric functions.
|
| 64 |
+
|
| 65 |
+
Parameters
|
| 66 |
+
----------
|
| 67 |
+
outputs: List[str]
|
| 68 |
+
List of outputs. Supported values are
|
| 69 |
+
`{"lat", "lon", "sin_lat", "cos_lat", "sin_lon", "cos_lon"}`
|
| 70 |
+
"""
|
| 71 |
+
self.outputs = outputs
|
| 72 |
+
|
| 73 |
+
def __call__(self, latlon: np.ndarray):
|
| 74 |
+
(lat, lon) = latlon
|
| 75 |
+
|
| 76 |
+
vars = {"lat": lat, "lon": lon}
|
| 77 |
+
|
| 78 |
+
# cos/sin latitudes and longitudes
|
| 79 |
+
if "sin_lat" in self.outputs:
|
| 80 |
+
vars["sin_lat"] = np.sin(np.deg2rad(lat))
|
| 81 |
+
if "cos_lat" in self.outputs:
|
| 82 |
+
vars["cos_lat"] = np.cos(np.deg2rad(lat))
|
| 83 |
+
if "sin_lon" in self.outputs:
|
| 84 |
+
vars["sin_lon"] = np.sin(np.deg2rad(lon))
|
| 85 |
+
if "cos_lon" in self.outputs:
|
| 86 |
+
vars["cos_lon"] = np.cos(np.deg2rad(lon))
|
| 87 |
+
|
| 88 |
+
return np.stack([vars[o] for o in self.outputs], axis=0)
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class FileInvariant(Invariant):
    """
    Loads an time-invariant variable from a NetCDF4 file. The file should
    contain one or more data variables of dimensions
    `(channels, latitude, longitude)` as well as variables `latitude` and
    `longitude` specifying these coordinates. `latitude` and `longitude`
    can be either 2D or 1D.

    Parameters
    ----------
    filename: str
        Path to the file containing the variable
    var_name: str
        The variable in the file containing the data
    normalize: bool, optional
        If True, normalize the data by to zero-mean and unit variance.
        Default False.
    interp_method: str, optional
        Any argument accepted by xarray.DataArray.interp.
        Default 'linear'.
    """

    def __init__(
        self,
        filename: str,
        var_name: str,
        normalize=False,
        interp_method="linear",
    ):
        # Read everything inside the context manager so the dataset handle is
        # closed promptly.
        with xr.open_dataset(filename) as ds:
            self.data = ds[var_name].astype(np.float32)
            self.lat = ds["latitude"].to_numpy().astype(np.float32)
            self.lon = ds["longitude"].to_numpy().astype(np.float32)

        # Expand 1D coordinate vectors into full 2D coordinate grids.
        if self.lat.ndim == 1:
            (self.lat, self.lon) = np.meshgrid(self.lat, self.lon, indexing="ij")

        if normalize:
            self.data = (self.data - self.data.mean()) / self.data.std()

        self.interp_method = interp_method

    def __call__(self, latlon: np.ndarray):
        lat, lon = latlon
        coord_dims = ["latitude", "longitude"]
        lat_coords = xr.DataArray(lat, dims=coord_dims)
        lon_coords = xr.DataArray(lon, dims=coord_dims)
        # Interpolate the stored field onto the requested grid.
        interpolated = self.data.interp(
            method=self.interp_method, latitude=lat_coords, longitude=lon_coords
        )
        return interpolated.to_numpy()
|
physics_mcp/source/physicsnemo/datapipes/climate/utils/zenith_angle.py
ADDED
|
@@ -0,0 +1,208 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# ignore_header_test
|
| 2 |
+
|
| 3 |
+
# climt/LICENSE
|
| 4 |
+
# @mcgibbon
|
| 5 |
+
# BSD License
|
| 6 |
+
# Copyright (c) 2016, Rodrigo Caballero
|
| 7 |
+
# All rights reserved.
|
| 8 |
+
# Redistribution and use in source and binary forms, with or without modification,
|
| 9 |
+
# are permitted provided that the following conditions are met:
|
| 10 |
+
# * Redistributions of source code must retain the above copyright notice, this
|
| 11 |
+
# list of conditions and the following disclaimer.
|
| 12 |
+
# * Redistributions in binary form must reproduce the above copyright notice, this
|
| 13 |
+
# list of conditions and the following disclaimer in the documentation and/or
|
| 14 |
+
# other materials provided with the distribution.
|
| 15 |
+
# * Neither the name of the copyright holder nor the names of its
|
| 16 |
+
# contributors may be used to endorse or promote products derived from this
|
| 17 |
+
# software without specific prior written permission.
|
| 18 |
+
# THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 19 |
+
# ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 20 |
+
# WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE DISCLAIMED.
|
| 21 |
+
# IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE FOR ANY DIRECT,
|
| 22 |
+
# INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL DAMAGES (INCLUDING,
|
| 23 |
+
# BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR SERVICES; LOSS OF USE,
|
| 24 |
+
# DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER CAUSED AND ON ANY THEORY
|
| 25 |
+
# OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY, OR TORT (INCLUDING NEGLIGENCE
|
| 26 |
+
# OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE OF THIS SOFTWARE, EVEN IF ADVISED
|
| 27 |
+
# OF THE POSSIBILITY OF SUCH DAMAGE.
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
import datetime
|
| 31 |
+
|
| 32 |
+
import numpy as np
|
| 33 |
+
import pytz
|
| 34 |
+
|
| 35 |
+
try:
|
| 36 |
+
import nvidia.dali as dali
|
| 37 |
+
except ImportError:
|
| 38 |
+
raise ImportError(
|
| 39 |
+
"DALI dataset requires NVIDIA DALI package to be installed. "
|
| 40 |
+
+ "The package can be installed at:\n"
|
| 41 |
+
+ "https://docs.nvidia.com/deeplearning/dali/user-guide/docs/installation.html"
|
| 42 |
+
)
|
| 43 |
+
|
| 44 |
+
RAD_PER_DEG = np.pi / 180.0
|
| 45 |
+
DATETIME_2000 = datetime.datetime(2000, 1, 1, 12, 0, 0, tzinfo=pytz.utc).timestamp()
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
def _dali_mod(a, b):
    """Floored modulo ``a mod b`` built from DALI primitives."""
    quotient_floor = dali.math.floor(a / b)
    return a - quotient_floor * b
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def cos_zenith_angle(
    time: dali.types.DALIDataType,
    latlon: dali.types.DALIDataType,
):
    """
    Dali datapipe for computing Cosine of sun-zenith angle for lon, lat at time (UTC).

    Parameters
    ----------
    time : dali.types.DALIDataType
        Time in seconds since 2000-01-01 12:00:00 UTC. Shape `(seq_length,)`.
    latlon : dali.types.DALIDataType
        Latitude and longitude in degrees. Shape `(2, nr_lat, nr_lon)`.

    Returns
    -------
    dali.types.DALIDataType
        Cosine of sun-zenith angle. Shape `(seq_length, 1, nr_lat, nr_lon)`.
    """
    # Broadcast time over the spatial axes and the lat/lon grids over the
    # sequence axis, converting degrees to radians along the way.
    time_b = time[:, dali.newaxis, dali.newaxis, dali.newaxis]
    lat_rad = latlon[dali.newaxis, 0:1, :, :] * RAD_PER_DEG
    lon_rad = latlon[dali.newaxis, 1:2, :, :] * RAD_PER_DEG
    return _star_cos_zenith(time_b, lat_rad, lon_rad)
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def _days_from_2000(model_time):  # pragma: no cover
    """Convert seconds since 2000-01-01 12:00:00 UTC to elapsed days."""
    seconds_per_day = 24.0 * 3600.0
    return (model_time - DATETIME_2000) / seconds_per_day
|
| 80 |
+
|
| 81 |
+
|
| 82 |
+
def _greenwich_mean_sidereal_time(model_time):
    """
    Greenwich mean sidereal time, in radians.

    Parameters
    ----------
    model_time
        Time in seconds since 2000-01-01 12:00:00 UTC.

    Reference:
        The AIAA 2006 implementation:
        http://www.celestrak.com/publications/AIAA/2006-6753/
    """
    jul_centuries = _days_from_2000(model_time) / 36525.0
    # GMST in seconds of time (AIAA 2006-6753):
    #   67310.54841 + (876600 h + 8640184.812866) T + 0.093104 T^2 - 6.2e-6 T^3
    # NOTE: the upstream (climt) code wrote the cubic coefficient as
    # `6.2 * 10e-6`, which equals 6.2e-5 -- ten times the published value.
    # Corrected here to 6.2e-6.
    theta = 67310.54841 + jul_centuries * (
        876600 * 3600
        + 8640184.812866
        + jul_centuries * (0.093104 - jul_centuries * 6.2e-6)
    )

    # Seconds of time -> degrees (240 s = 1 deg) -> radians, wrapped to
    # [0, 2*pi).
    theta_radians = _dali_mod((theta / 240.0) * RAD_PER_DEG, 2 * np.pi)
    return theta_radians
|
| 98 |
+
|
| 99 |
+
|
| 100 |
+
def _local_mean_sidereal_time(model_time, longitude):
    """
    Local mean sidereal time in radians; requires ``longitude`` in radians.
    Ref:
        http://www.setileague.org/askdr/lmst.htm
    """
    gmst = _greenwich_mean_sidereal_time(model_time)
    return gmst + longitude
|
| 107 |
+
|
| 108 |
+
|
| 109 |
+
def _sun_ecliptic_longitude(model_time):
    """
    Ecliptic longitude of the sun, in radians.
    Reference:
        http://www.geoastro.de/elevaz/basics/meeus.htm
    """
    t = _days_from_2000(model_time) / 36525.0  # Julian centuries since J2000

    # Mean anomaly of the sun (cubic polynomial in t), degrees -> radians.
    mean_anom = (
        357.52910
        + 35999.05030 * t
        - 0.0001559 * t * t
        - 0.00000048 * t * t * t
    ) * RAD_PER_DEG

    # Mean longitude of the sun (quadratic in t), degrees -> radians.
    mean_lon = (
        280.46645 + 36000.76983 * t + 0.0003032 * (t**2)
    ) * RAD_PER_DEG

    # Equation-of-center correction, degrees -> radians.
    correction = (
        (1.914600 - 0.004817 * t - 0.000014 * (t**2))
        * dali.math.sin(mean_anom)
        + (0.019993 - 0.000101 * t) * dali.math.sin(2 * mean_anom)
        + 0.000290 * dali.math.sin(3 * mean_anom)
    ) * RAD_PER_DEG

    # True longitude = mean longitude + equation of center.
    return mean_lon + correction
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def _obliquity_star(julian_centuries):
    """
    Obliquity of the ecliptic, in radians.
    Uses the 5th-order polynomial from
    https://en.wikipedia.org/wiki/Ecliptic#Obliquity_of_the_ecliptic
    """
    t = julian_centuries
    # Base value 23 deg 26' 21.406'', expressed in degrees.
    base_deg = 23.0 + 26.0 / 60 + 21.406 / 3600.0
    # Secular correction in arcseconds.
    correction_arcsec = (
        46.836769 * t
        - 0.0001831 * (t**2)
        + 0.00200340 * (t**3)
        - 0.576e-6 * (t**4)
        - 4.34e-8 * (t**5)
    )
    return (base_deg - correction_arcsec / 3600.0) * RAD_PER_DEG
|
| 160 |
+
|
| 161 |
+
|
| 162 |
+
def _right_ascension_declination(model_time):
    """
    Right ascension and declination of the sun, in radians, as a tuple.
    """
    t = _days_from_2000(model_time) / 36525.0
    obliquity = _obliquity_star(t)

    ecliptic_lon = _sun_ecliptic_longitude(model_time)
    # Unit vector of the sun in equatorial coordinates.
    x = dali.math.cos(ecliptic_lon)
    y = dali.math.cos(obliquity) * dali.math.sin(ecliptic_lon)
    z = dali.math.sin(obliquity) * dali.math.sin(ecliptic_lon)
    r = dali.math.sqrt(1.0 - z * z)

    declination = dali.math.atan2(z, r)  # sun declination
    right_ascension = 2 * dali.math.atan2(y, (x + r))  # right ascension
    return right_ascension, declination
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def _local_hour_angle(model_time, longitude, right_ascension):
    """
    Hour angle at ``model_time`` for the given longitude (radians) and
    right ascension.
    Ref:
        https://en.wikipedia.org/wiki/Hour_angle#Relation_with_the_right_ascension
    """
    lmst = _local_mean_sidereal_time(model_time, longitude)
    return lmst - right_ascension
|
| 189 |
+
|
| 190 |
+
|
| 191 |
+
def _star_cos_zenith(model_time, lat, lon):
    """
    Return cosine of the star (sun) zenith angle; ``lat``/``lon`` in radians.
    Ref:
        Azimuth:
        https://en.wikipedia.org/wiki/Solar_azimuth_angle#Formulas
        Zenith:
        https://en.wikipedia.org/wiki/Solar_zenith_angle
    """
    right_ascension, declination = _right_ascension_declination(model_time)
    hour_angle = _local_hour_angle(model_time, lon, right_ascension)

    # cos(zenith) = sin(lat) sin(dec) + cos(lat) cos(dec) cos(H)
    sin_term = dali.math.sin(lat) * dali.math.sin(declination)
    cos_term = (
        dali.math.cos(lat) * dali.math.cos(declination) * dali.math.cos(hour_angle)
    )
    return sin_term + cos_term
|
physics_mcp/source/physicsnemo/datapipes/datapipe.py
ADDED
|
@@ -0,0 +1,60 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import logging
|
| 18 |
+
|
| 19 |
+
from physicsnemo.datapipes.meta import DatapipeMetaData
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
class Datapipe:
    """The base class for all datapipes in PhysicsNeMo.

    Parameters
    ----------
    meta : DatapipeMetaData, optional
        Meta data class for storing info regarding model, by default None.
        Anything that is not a ``DatapipeMetaData`` instance is replaced
        with a default-constructed ``DatapipeMetaData``.
    """

    def __init__(self, meta: DatapipeMetaData = None):
        super().__init__()

        # Fall back to default metadata when none (or a wrong type) is given.
        if not meta or not isinstance(meta, DatapipeMetaData):
            self.meta = DatapipeMetaData()
        else:
            self.meta = meta

        self.logger = logging.getLogger("core.datapipe")
        # The "core.datapipe" logger is shared by all Datapipe instances.
        # Only attach a handler when none is present yet; previously a new
        # StreamHandler was added on every instantiation, so each message
        # was emitted once per constructed datapipe.
        if not self.logger.handlers:
            handler = logging.StreamHandler()
            formatter = logging.Formatter(
                "[%(asctime)s - %(levelname)s] %(message)s", datefmt="%H:%M:%S"
            )
            handler.setFormatter(formatter)
            self.logger.addHandler(handler)
        self.logger.setLevel(logging.WARNING)

    def debug(self):
        """Turn on debug logging"""
        # Replace existing handlers with a single debug-formatted one that
        # includes the datapipe name and full date.
        self.logger.handlers.clear()
        handler = logging.StreamHandler()
        formatter = logging.Formatter(
            f"[%(asctime)s - %(levelname)s - {self.meta.name}] %(message)s",
            datefmt="%Y-%m-%d %H:%M:%S",
        )
        handler.setFormatter(formatter)
        self.logger.addHandler(handler)
        self.logger.setLevel(logging.DEBUG)
        # TODO: set up debug log
        # fh = logging.FileHandler(f'physicsnemo-core-{self.meta.name}.log')
|
physics_mcp/source/physicsnemo/datapipes/gnn/__init__.py
ADDED
|
@@ -0,0 +1,15 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2023 - 2025 NVIDIA CORPORATION & AFFILIATES.
|
| 2 |
+
# SPDX-FileCopyrightText: All rights reserved.
|
| 3 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|