Add files using upload-large-folder tool
Browse files- pythonProject/.venv/Lib/site-packages/accelerate/commands/launch.py +1245 -0
- pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER +1 -0
- pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/METADATA +441 -0
- pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/RECORD +31 -0
- pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL +5 -0
- pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt +27 -0
- pythonProject/.venv/Lib/site-packages/colorama/__pycache__/__init__.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/colorama/ansi.py +102 -0
- pythonProject/.venv/Lib/site-packages/colorama/ansitowin32.py +277 -0
- pythonProject/.venv/Lib/site-packages/colorama/initialise.py +121 -0
- pythonProject/.venv/Lib/site-packages/colorama/win32.py +180 -0
- pythonProject/.venv/Lib/site-packages/diffusers/callbacks.py +244 -0
- pythonProject/.venv/Lib/site-packages/diffusers/configuration_utils.py +769 -0
- pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_check.py +34 -0
- pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_table.py +54 -0
- pythonProject/.venv/Lib/site-packages/diffusers/image_processor.py +1451 -0
- pythonProject/.venv/Lib/site-packages/diffusers/optimization.py +361 -0
- pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-310.pyc +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/py.typed +0 -0
- pythonProject/.venv/Lib/site-packages/diffusers/training_utils.py +730 -0
pythonProject/.venv/Lib/site-packages/accelerate/commands/launch.py
ADDED
|
@@ -0,0 +1,1245 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
#!/usr/bin/env python
|
| 2 |
+
|
| 3 |
+
# Copyright 2021 The HuggingFace Team. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
|
| 17 |
+
import argparse
|
| 18 |
+
import importlib
|
| 19 |
+
import logging
|
| 20 |
+
import os
|
| 21 |
+
import subprocess
|
| 22 |
+
import sys
|
| 23 |
+
from pathlib import Path
|
| 24 |
+
|
| 25 |
+
import psutil
|
| 26 |
+
import torch
|
| 27 |
+
|
| 28 |
+
from accelerate.commands.config import default_config_file, load_config_from_file
|
| 29 |
+
from accelerate.commands.config.config_args import SageMakerConfig
|
| 30 |
+
from accelerate.commands.config.config_utils import DYNAMO_BACKENDS
|
| 31 |
+
from accelerate.commands.utils import CustomArgumentParser
|
| 32 |
+
from accelerate.state import get_int_from_env
|
| 33 |
+
from accelerate.utils import (
|
| 34 |
+
ComputeEnvironment,
|
| 35 |
+
DistributedType,
|
| 36 |
+
PrepareForLaunch,
|
| 37 |
+
_filter_args,
|
| 38 |
+
check_cuda_p2p_ib_support,
|
| 39 |
+
convert_dict_to_env_variables,
|
| 40 |
+
is_bf16_available,
|
| 41 |
+
is_deepspeed_available,
|
| 42 |
+
is_hpu_available,
|
| 43 |
+
is_mlu_available,
|
| 44 |
+
is_musa_available,
|
| 45 |
+
is_npu_available,
|
| 46 |
+
is_rich_available,
|
| 47 |
+
is_sagemaker_available,
|
| 48 |
+
is_sdaa_available,
|
| 49 |
+
is_torch_xla_available,
|
| 50 |
+
is_xpu_available,
|
| 51 |
+
patch_environment,
|
| 52 |
+
prepare_deepspeed_cmd_env,
|
| 53 |
+
prepare_multi_gpu_env,
|
| 54 |
+
prepare_sagemager_args_inputs,
|
| 55 |
+
prepare_simple_launcher_cmd_env,
|
| 56 |
+
prepare_tpu,
|
| 57 |
+
str_to_bool,
|
| 58 |
+
)
|
| 59 |
+
from accelerate.utils.constants import DEEPSPEED_MULTINODE_LAUNCHERS, TORCH_DYNAMO_MODES
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
if is_rich_available():
|
| 63 |
+
from rich import get_console
|
| 64 |
+
from rich.logging import RichHandler
|
| 65 |
+
|
| 66 |
+
FORMAT = "%(message)s"
|
| 67 |
+
logging.basicConfig(format=FORMAT, datefmt="[%X]", handlers=[RichHandler()])
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
logger = logging.getLogger(__name__)
|
| 71 |
+
|
| 72 |
+
|
| 73 |
+
options_to_group = {
|
| 74 |
+
"multi_gpu": "Distributed GPUs",
|
| 75 |
+
"tpu": "TPU",
|
| 76 |
+
"use_deepspeed": "DeepSpeed Arguments",
|
| 77 |
+
"use_fsdp": "FSDP Arguments",
|
| 78 |
+
"use_megatron_lm": "Megatron-LM Arguments",
|
| 79 |
+
"fp8_backend": "FP8 Arguments",
|
| 80 |
+
}
|
| 81 |
+
|
| 82 |
+
|
| 83 |
+
def clean_option(option):
|
| 84 |
+
"Finds all cases of - after the first two characters and changes them to _"
|
| 85 |
+
if "fp8_backend" in option:
|
| 86 |
+
option = "--fp8_backend"
|
| 87 |
+
if option.startswith("--"):
|
| 88 |
+
return option[2:].replace("-", "_")
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class CustomHelpFormatter(argparse.HelpFormatter):
|
| 92 |
+
"""
|
| 93 |
+
This is a custom help formatter that will hide all arguments that are not used in the command line when the help is
|
| 94 |
+
called. This is useful for the case where the user is using a specific platform and only wants to see the arguments
|
| 95 |
+
for that platform.
|
| 96 |
+
"""
|
| 97 |
+
|
| 98 |
+
def __init__(self, *args, **kwargs):
|
| 99 |
+
super().__init__(*args, **kwargs)
|
| 100 |
+
self.titles = [
|
| 101 |
+
"Hardware Selection Arguments",
|
| 102 |
+
"Resource Selection Arguments",
|
| 103 |
+
"Training Paradigm Arguments",
|
| 104 |
+
"positional arguments",
|
| 105 |
+
"optional arguments",
|
| 106 |
+
]
|
| 107 |
+
|
| 108 |
+
def add_argument(self, action: argparse.Action):
|
| 109 |
+
if "accelerate" in sys.argv[0] and "launch" in sys.argv[1:]:
|
| 110 |
+
args = sys.argv[2:]
|
| 111 |
+
else:
|
| 112 |
+
args = sys.argv[1:]
|
| 113 |
+
|
| 114 |
+
if len(args) > 1:
|
| 115 |
+
args = list(map(clean_option, args))
|
| 116 |
+
used_platforms = [arg for arg in args if arg in options_to_group.keys()]
|
| 117 |
+
used_titles = [options_to_group[o] for o in used_platforms]
|
| 118 |
+
if action.container.title not in self.titles + used_titles:
|
| 119 |
+
action.help = argparse.SUPPRESS
|
| 120 |
+
elif action.container.title == "Hardware Selection Arguments":
|
| 121 |
+
if set(action.option_strings).isdisjoint(set(args)):
|
| 122 |
+
action.help = argparse.SUPPRESS
|
| 123 |
+
else:
|
| 124 |
+
action.help = action.help + " (currently selected)"
|
| 125 |
+
elif action.container.title == "Training Paradigm Arguments":
|
| 126 |
+
if set(action.option_strings).isdisjoint(set(args)):
|
| 127 |
+
action.help = argparse.SUPPRESS
|
| 128 |
+
else:
|
| 129 |
+
action.help = action.help + " (currently selected)"
|
| 130 |
+
|
| 131 |
+
action.option_strings = [s for s in action.option_strings if "-" not in s[2:]]
|
| 132 |
+
super().add_argument(action)
|
| 133 |
+
|
| 134 |
+
def end_section(self):
|
| 135 |
+
if len(self._current_section.items) < 2:
|
| 136 |
+
self._current_section.items = []
|
| 137 |
+
self._current_section.heading = ""
|
| 138 |
+
super().end_section()
|
| 139 |
+
|
| 140 |
+
|
| 141 |
+
def launch_command_parser(subparsers=None):
|
| 142 |
+
description = "Launch a python script in a distributed scenario. Arguments can be passed in with either hyphens (`--num-processes=2`) or underscores (`--num_processes=2`)"
|
| 143 |
+
if subparsers is not None:
|
| 144 |
+
parser = subparsers.add_parser(
|
| 145 |
+
"launch", description=description, add_help=False, allow_abbrev=False, formatter_class=CustomHelpFormatter
|
| 146 |
+
)
|
| 147 |
+
else:
|
| 148 |
+
parser = CustomArgumentParser(
|
| 149 |
+
"Accelerate launch command",
|
| 150 |
+
description=description,
|
| 151 |
+
add_help=False,
|
| 152 |
+
allow_abbrev=False,
|
| 153 |
+
formatter_class=CustomHelpFormatter,
|
| 154 |
+
)
|
| 155 |
+
|
| 156 |
+
parser.add_argument("-h", "--help", action="help", help="Show this help message and exit.")
|
| 157 |
+
|
| 158 |
+
parser.add_argument(
|
| 159 |
+
"--config_file",
|
| 160 |
+
default=None,
|
| 161 |
+
help="The config file to use for the default values in the launching script.",
|
| 162 |
+
)
|
| 163 |
+
parser.add_argument(
|
| 164 |
+
"--quiet",
|
| 165 |
+
"-q",
|
| 166 |
+
action="store_true",
|
| 167 |
+
help="Silence subprocess errors from the launch stack trace and only show the relevant tracebacks. (Only applicable to DeepSpeed and single-process configurations)",
|
| 168 |
+
)
|
| 169 |
+
# Hardware selection arguments
|
| 170 |
+
hardware_args = parser.add_argument_group(
|
| 171 |
+
"Hardware Selection Arguments", "Arguments for selecting the hardware to be used."
|
| 172 |
+
)
|
| 173 |
+
hardware_args.add_argument(
|
| 174 |
+
"--cpu", default=False, action="store_true", help="Whether or not to force the training on the CPU."
|
| 175 |
+
)
|
| 176 |
+
hardware_args.add_argument(
|
| 177 |
+
"--multi_gpu",
|
| 178 |
+
default=False,
|
| 179 |
+
action="store_true",
|
| 180 |
+
help="Whether or not this should launch a distributed GPU training.",
|
| 181 |
+
)
|
| 182 |
+
hardware_args.add_argument(
|
| 183 |
+
"--tpu", default=False, action="store_true", help="Whether or not this should launch a TPU training."
|
| 184 |
+
)
|
| 185 |
+
# Resource selection arguments
|
| 186 |
+
resource_args = parser.add_argument_group(
|
| 187 |
+
"Resource Selection Arguments", "Arguments for fine-tuning how available hardware should be used."
|
| 188 |
+
)
|
| 189 |
+
resource_args.add_argument(
|
| 190 |
+
"--mixed_precision",
|
| 191 |
+
type=str,
|
| 192 |
+
choices=["no", "fp16", "bf16", "fp8"],
|
| 193 |
+
help="Whether or not to use mixed precision training. "
|
| 194 |
+
"Choose between FP16 and BF16 (bfloat16) training. "
|
| 195 |
+
"BF16 training is only supported on Nvidia Ampere GPUs and PyTorch 1.10 or later.",
|
| 196 |
+
)
|
| 197 |
+
resource_args.add_argument(
|
| 198 |
+
"--num_processes", type=int, default=None, help="The total number of processes to be launched in parallel."
|
| 199 |
+
)
|
| 200 |
+
resource_args.add_argument(
|
| 201 |
+
"--num_machines", type=int, default=None, help="The total number of machines used in this training."
|
| 202 |
+
)
|
| 203 |
+
resource_args.add_argument(
|
| 204 |
+
"--num_cpu_threads_per_process",
|
| 205 |
+
type=int,
|
| 206 |
+
default=None,
|
| 207 |
+
help="The number of CPU threads per process. Can be tuned for optimal performance.",
|
| 208 |
+
)
|
| 209 |
+
resource_args.add_argument(
|
| 210 |
+
"--enable_cpu_affinity",
|
| 211 |
+
default=False,
|
| 212 |
+
action="store_true",
|
| 213 |
+
help="Whether or not CPU affinity and balancing should be enabled. Currently only supported on NVIDIA hardware.",
|
| 214 |
+
)
|
| 215 |
+
# Dynamo arguments
|
| 216 |
+
resource_args.add_argument(
|
| 217 |
+
"--dynamo_backend",
|
| 218 |
+
type=str,
|
| 219 |
+
choices=["no"] + [b.lower() for b in DYNAMO_BACKENDS],
|
| 220 |
+
help="Choose a backend to optimize your training with dynamo, see more at "
|
| 221 |
+
"https://github.com/pytorch/torchdynamo.",
|
| 222 |
+
)
|
| 223 |
+
resource_args.add_argument(
|
| 224 |
+
"--dynamo_mode",
|
| 225 |
+
type=str,
|
| 226 |
+
default="default",
|
| 227 |
+
choices=TORCH_DYNAMO_MODES,
|
| 228 |
+
help="Choose a mode to optimize your training with dynamo.",
|
| 229 |
+
)
|
| 230 |
+
resource_args.add_argument(
|
| 231 |
+
"--dynamo_use_fullgraph",
|
| 232 |
+
default=False,
|
| 233 |
+
action="store_true",
|
| 234 |
+
help="Whether to use full graph mode for dynamo or it is ok to break model into several subgraphs",
|
| 235 |
+
)
|
| 236 |
+
resource_args.add_argument(
|
| 237 |
+
"--dynamo_use_dynamic",
|
| 238 |
+
default=False,
|
| 239 |
+
action="store_true",
|
| 240 |
+
help="Whether to enable dynamic shape tracing.",
|
| 241 |
+
)
|
| 242 |
+
resource_args.add_argument(
|
| 243 |
+
"--dynamo_use_regional_compilation",
|
| 244 |
+
default=False,
|
| 245 |
+
action="store_true",
|
| 246 |
+
help="Whether to enable regional compilation.",
|
| 247 |
+
)
|
| 248 |
+
|
| 249 |
+
# Training Paradigm arguments
|
| 250 |
+
paradigm_args = parser.add_argument_group(
|
| 251 |
+
"Training Paradigm Arguments", "Arguments for selecting which training paradigm to be used."
|
| 252 |
+
)
|
| 253 |
+
paradigm_args.add_argument(
|
| 254 |
+
"--use_deepspeed",
|
| 255 |
+
default=False,
|
| 256 |
+
action="store_true",
|
| 257 |
+
help="Whether to use deepspeed.",
|
| 258 |
+
)
|
| 259 |
+
paradigm_args.add_argument(
|
| 260 |
+
"--use_fsdp",
|
| 261 |
+
default=False,
|
| 262 |
+
action="store_true",
|
| 263 |
+
help="Whether to use fsdp.",
|
| 264 |
+
)
|
| 265 |
+
paradigm_args.add_argument(
|
| 266 |
+
"--use_parallelism_config",
|
| 267 |
+
default=False,
|
| 268 |
+
action="store_true",
|
| 269 |
+
help="Whether to use the parallelism config to configure the N-d distributed training.",
|
| 270 |
+
)
|
| 271 |
+
paradigm_args.add_argument(
|
| 272 |
+
"--use_megatron_lm",
|
| 273 |
+
default=False,
|
| 274 |
+
action="store_true",
|
| 275 |
+
help="Whether to use Megatron-LM.",
|
| 276 |
+
)
|
| 277 |
+
|
| 278 |
+
paradigm_args.add_argument(
|
| 279 |
+
"--use_xpu",
|
| 280 |
+
default=None,
|
| 281 |
+
action="store_true",
|
| 282 |
+
help="Whether to use IPEX plugin to speed up training on XPU specifically. This argument is deprecated and ignored, will be removed in Accelerate v1.20.",
|
| 283 |
+
)
|
| 284 |
+
|
| 285 |
+
# distributed GPU training arguments
|
| 286 |
+
distributed_args = parser.add_argument_group("Distributed GPUs", "Arguments related to distributed GPU training.")
|
| 287 |
+
distributed_args.add_argument(
|
| 288 |
+
"--gpu_ids",
|
| 289 |
+
default=None,
|
| 290 |
+
help="What GPUs (by id) should be used for training on this machine as a comma-separated list",
|
| 291 |
+
)
|
| 292 |
+
distributed_args.add_argument(
|
| 293 |
+
"--same_network",
|
| 294 |
+
default=False,
|
| 295 |
+
action="store_true",
|
| 296 |
+
help="Whether all machines used for multinode training exist on the same local network.",
|
| 297 |
+
)
|
| 298 |
+
distributed_args.add_argument(
|
| 299 |
+
"--machine_rank", type=int, default=None, help="The rank of the machine on which this script is launched."
|
| 300 |
+
)
|
| 301 |
+
distributed_args.add_argument(
|
| 302 |
+
"--main_process_ip", type=str, default=None, help="The IP address of the machine of rank 0."
|
| 303 |
+
)
|
| 304 |
+
distributed_args.add_argument(
|
| 305 |
+
"--main_process_port",
|
| 306 |
+
type=int,
|
| 307 |
+
default=None,
|
| 308 |
+
help="The port to use to communicate with the machine of rank 0.",
|
| 309 |
+
)
|
| 310 |
+
distributed_args.add_argument(
|
| 311 |
+
"-t",
|
| 312 |
+
"--tee",
|
| 313 |
+
default="0",
|
| 314 |
+
type=str,
|
| 315 |
+
help="Tee std streams into a log file and also to console.",
|
| 316 |
+
)
|
| 317 |
+
distributed_args.add_argument(
|
| 318 |
+
"--log_dir",
|
| 319 |
+
type=str,
|
| 320 |
+
default=None,
|
| 321 |
+
help=(
|
| 322 |
+
"Base directory to use for log files when using torchrun/torch.distributed.run as launcher. "
|
| 323 |
+
"Use with --tee to redirect std streams info log files."
|
| 324 |
+
),
|
| 325 |
+
)
|
| 326 |
+
distributed_args.add_argument(
|
| 327 |
+
"--role",
|
| 328 |
+
type=str,
|
| 329 |
+
default="default",
|
| 330 |
+
help="User-defined role for the workers.",
|
| 331 |
+
)
|
| 332 |
+
# Rendezvous related arguments
|
| 333 |
+
distributed_args.add_argument(
|
| 334 |
+
"--rdzv_backend",
|
| 335 |
+
type=str,
|
| 336 |
+
default="static",
|
| 337 |
+
help="The rendezvous method to use, such as 'static' (the default) or 'c10d'",
|
| 338 |
+
)
|
| 339 |
+
distributed_args.add_argument(
|
| 340 |
+
"--rdzv_conf",
|
| 341 |
+
type=str,
|
| 342 |
+
default="",
|
| 343 |
+
help="Additional rendezvous configuration (<key1>=<value1>,<key2>=<value2>,...).",
|
| 344 |
+
)
|
| 345 |
+
distributed_args.add_argument(
|
| 346 |
+
"--max_restarts",
|
| 347 |
+
type=int,
|
| 348 |
+
default=0,
|
| 349 |
+
help="Maximum number of worker group restarts before failing.",
|
| 350 |
+
)
|
| 351 |
+
distributed_args.add_argument(
|
| 352 |
+
"--monitor_interval",
|
| 353 |
+
type=float,
|
| 354 |
+
default=0.1,
|
| 355 |
+
help="Interval, in seconds, to monitor the state of workers.",
|
| 356 |
+
)
|
| 357 |
+
parser.add_argument(
|
| 358 |
+
"-m",
|
| 359 |
+
"--module",
|
| 360 |
+
action="store_true",
|
| 361 |
+
help="Change each process to interpret the launch script as a Python module, executing with the same behavior as 'python -m'.",
|
| 362 |
+
)
|
| 363 |
+
parser.add_argument(
|
| 364 |
+
"--no_python",
|
| 365 |
+
action="store_true",
|
| 366 |
+
help="Skip prepending the training script with 'python' - just execute it directly. Useful when the script is not a Python script.",
|
| 367 |
+
)
|
| 368 |
+
|
| 369 |
+
# TPU arguments
|
| 370 |
+
tpu_args = parser.add_argument_group("TPU", "Arguments related to TPU.")
|
| 371 |
+
tpu_args.add_argument(
|
| 372 |
+
"--tpu_cluster",
|
| 373 |
+
action="store_true",
|
| 374 |
+
dest="tpu_use_cluster",
|
| 375 |
+
help="Whether to use a GCP TPU pod for training.",
|
| 376 |
+
)
|
| 377 |
+
tpu_args.add_argument(
|
| 378 |
+
"--no_tpu_cluster",
|
| 379 |
+
action="store_false",
|
| 380 |
+
dest="tpu_use_cluster",
|
| 381 |
+
help="Should not be passed explicitly, this is for internal use only.",
|
| 382 |
+
)
|
| 383 |
+
tpu_args.add_argument(
|
| 384 |
+
"--tpu_use_sudo",
|
| 385 |
+
action="store_true",
|
| 386 |
+
help="Whether to use `sudo` when running the TPU training script in each pod.",
|
| 387 |
+
)
|
| 388 |
+
tpu_args.add_argument(
|
| 389 |
+
"--vm",
|
| 390 |
+
type=str,
|
| 391 |
+
action="append",
|
| 392 |
+
help=(
|
| 393 |
+
"List of single Compute VM instance names. "
|
| 394 |
+
"If not provided we assume usage of instance groups. For TPU pods."
|
| 395 |
+
),
|
| 396 |
+
)
|
| 397 |
+
tpu_args.add_argument(
|
| 398 |
+
"--env",
|
| 399 |
+
type=str,
|
| 400 |
+
action="append",
|
| 401 |
+
help="List of environment variables to set on the Compute VM instances. For TPU pods.",
|
| 402 |
+
)
|
| 403 |
+
tpu_args.add_argument(
|
| 404 |
+
"--main_training_function",
|
| 405 |
+
type=str,
|
| 406 |
+
default=None,
|
| 407 |
+
help="The name of the main function to be executed in your script (only for TPU training).",
|
| 408 |
+
)
|
| 409 |
+
tpu_args.add_argument(
|
| 410 |
+
"--downcast_bf16",
|
| 411 |
+
action="store_true",
|
| 412 |
+
help="Whether when using bf16 precision on TPUs if both float and double tensors are cast to bfloat16 or if double tensors remain as float32.",
|
| 413 |
+
)
|
| 414 |
+
|
| 415 |
+
# DeepSpeed arguments
|
| 416 |
+
deepspeed_args = parser.add_argument_group("DeepSpeed Arguments", "Arguments related to DeepSpeed.")
|
| 417 |
+
deepspeed_args.add_argument(
|
| 418 |
+
"--deepspeed_config_file",
|
| 419 |
+
default=None,
|
| 420 |
+
type=str,
|
| 421 |
+
help="DeepSpeed config file.",
|
| 422 |
+
)
|
| 423 |
+
deepspeed_args.add_argument(
|
| 424 |
+
"--zero_stage",
|
| 425 |
+
default=None,
|
| 426 |
+
type=int,
|
| 427 |
+
help="DeepSpeed's ZeRO optimization stage (useful only when `use_deepspeed` flag is passed). "
|
| 428 |
+
"If unspecified, will default to `2`.",
|
| 429 |
+
)
|
| 430 |
+
deepspeed_args.add_argument(
|
| 431 |
+
"--offload_optimizer_device",
|
| 432 |
+
default=None,
|
| 433 |
+
type=str,
|
| 434 |
+
help="Decides where (none|cpu|nvme) to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 435 |
+
"If unspecified, will default to 'none'.",
|
| 436 |
+
)
|
| 437 |
+
deepspeed_args.add_argument(
|
| 438 |
+
"--offload_param_device",
|
| 439 |
+
default=None,
|
| 440 |
+
type=str,
|
| 441 |
+
help="Decides where (none|cpu|nvme) to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 442 |
+
"If unspecified, will default to 'none'.",
|
| 443 |
+
)
|
| 444 |
+
deepspeed_args.add_argument(
|
| 445 |
+
"--offload_optimizer_nvme_path",
|
| 446 |
+
default=None,
|
| 447 |
+
type=str,
|
| 448 |
+
help="Decides Nvme Path to offload optimizer states (useful only when `use_deepspeed` flag is passed). "
|
| 449 |
+
"If unspecified, will default to 'none'.",
|
| 450 |
+
)
|
| 451 |
+
deepspeed_args.add_argument(
|
| 452 |
+
"--offload_param_nvme_path",
|
| 453 |
+
default=None,
|
| 454 |
+
type=str,
|
| 455 |
+
help="Decides Nvme Path to offload parameters (useful only when `use_deepspeed` flag is passed). "
|
| 456 |
+
"If unspecified, will default to 'none'.",
|
| 457 |
+
)
|
| 458 |
+
deepspeed_args.add_argument(
|
| 459 |
+
"--gradient_accumulation_steps",
|
| 460 |
+
default=None,
|
| 461 |
+
type=int,
|
| 462 |
+
help="No of gradient_accumulation_steps used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 463 |
+
"If unspecified, will default to `1`.",
|
| 464 |
+
)
|
| 465 |
+
deepspeed_args.add_argument(
|
| 466 |
+
"--gradient_clipping",
|
| 467 |
+
default=None,
|
| 468 |
+
type=float,
|
| 469 |
+
help="gradient clipping value used in your training script (useful only when `use_deepspeed` flag is passed). "
|
| 470 |
+
"If unspecified, will default to `1.0`.",
|
| 471 |
+
)
|
| 472 |
+
deepspeed_args.add_argument(
|
| 473 |
+
"--zero3_init_flag",
|
| 474 |
+
default=None,
|
| 475 |
+
type=str,
|
| 476 |
+
help="Decides Whether (true|false) to enable `deepspeed.zero.Init` for constructing massive models. "
|
| 477 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `true`.",
|
| 478 |
+
)
|
| 479 |
+
deepspeed_args.add_argument(
|
| 480 |
+
"--zero3_save_16bit_model",
|
| 481 |
+
default=None,
|
| 482 |
+
type=str,
|
| 483 |
+
help="Decides Whether (true|false) to save 16-bit model weights when using ZeRO Stage-3. "
|
| 484 |
+
"Only applicable with DeepSpeed ZeRO Stage-3. If unspecified, will default to `false`.",
|
| 485 |
+
)
|
| 486 |
+
deepspeed_args.add_argument(
|
| 487 |
+
"--deepspeed_hostfile",
|
| 488 |
+
default=None,
|
| 489 |
+
type=str,
|
| 490 |
+
help="DeepSpeed hostfile for configuring multi-node compute resources.",
|
| 491 |
+
)
|
| 492 |
+
deepspeed_args.add_argument(
|
| 493 |
+
"--deepspeed_exclusion_filter",
|
| 494 |
+
default=None,
|
| 495 |
+
type=str,
|
| 496 |
+
help="DeepSpeed exclusion filter string when using mutli-node setup.",
|
| 497 |
+
)
|
| 498 |
+
deepspeed_args.add_argument(
|
| 499 |
+
"--deepspeed_inclusion_filter",
|
| 500 |
+
default=None,
|
| 501 |
+
type=str,
|
| 502 |
+
help="DeepSpeed inclusion filter string when using mutli-node setup.",
|
| 503 |
+
)
|
| 504 |
+
deepspeed_args.add_argument(
|
| 505 |
+
"--deepspeed_multinode_launcher",
|
| 506 |
+
default=None,
|
| 507 |
+
type=str,
|
| 508 |
+
help="DeepSpeed multi-node launcher to use, e.g. `pdsh`, `standard`, `openmpi`, `mvapich`, `mpich`, `slurm`, `nossh` (requires DeepSpeed >= 0.14.5). If unspecified, will default to `pdsh`.",
|
| 509 |
+
)
|
| 510 |
+
deepspeed_args.add_argument(
|
| 511 |
+
"--deepspeed_moe_layer_cls_names",
|
| 512 |
+
default=None,
|
| 513 |
+
type=str,
|
| 514 |
+
help="comma-separated list of transformer MoE layer class names (case-sensitive) to wrap ,e.g, `MixtralSparseMoeBlock`, `Qwen2MoeSparseMoeBlock`, `JetMoEAttention,JetMoEBlock` ..."
|
| 515 |
+
" (useful only when `use_deepspeed` flag is passed).",
|
| 516 |
+
)
|
| 517 |
+
|
| 518 |
+
# fsdp arguments
|
| 519 |
+
fsdp_args = parser.add_argument_group("FSDP Arguments", "Arguments related to Fully Shared Data Parallelism.")
|
| 520 |
+
fsdp_args.add_argument(
|
| 521 |
+
"--fsdp_version",
|
| 522 |
+
type=str,
|
| 523 |
+
default="1",
|
| 524 |
+
choices=["1", "2"],
|
| 525 |
+
help="FSDP version to use. (useful only when `use_fsdp` flag is passed).",
|
| 526 |
+
)
|
| 527 |
+
fsdp_args.add_argument(
|
| 528 |
+
"--fsdp_offload_params",
|
| 529 |
+
default="false",
|
| 530 |
+
type=str,
|
| 531 |
+
help="Decides Whether (true|false) to offload parameters and gradients to CPU. (useful only when `use_fsdp` flag is passed).",
|
| 532 |
+
)
|
| 533 |
+
fsdp_args.add_argument(
|
| 534 |
+
"--fsdp_min_num_params",
|
| 535 |
+
type=int,
|
| 536 |
+
default=1e8,
|
| 537 |
+
help="FSDP's minimum number of parameters for Default Auto Wrapping. (useful only when `use_fsdp` flag is passed).",
|
| 538 |
+
)
|
| 539 |
+
# We enable this for backwards compatibility, throw a warning if this is set in `FullyShardedDataParallelPlugin`
|
| 540 |
+
fsdp_args.add_argument(
|
| 541 |
+
"--fsdp_sharding_strategy",
|
| 542 |
+
type=str,
|
| 543 |
+
default="FULL_SHARD",
|
| 544 |
+
help="FSDP's sharding strategy. (useful only when `use_fsdp` flag is passed and `fsdp_version=1`).",
|
| 545 |
+
)
|
| 546 |
+
fsdp_args.add_argument(
|
| 547 |
+
"--fsdp_reshard_after_forward",
|
| 548 |
+
type=str,
|
| 549 |
+
default="true",
|
| 550 |
+
help="FSDP's Reshard After Forward Strategy. (useful only when `use_fsdp` flag is passed). Supports either boolean (FSDP2) or `FULL_SHARD | SHARD_GRAD_OP | NO_RESHARD` (FSDP1).",
|
| 551 |
+
)
|
| 552 |
+
fsdp_args.add_argument(
|
| 553 |
+
"--fsdp_auto_wrap_policy",
|
| 554 |
+
type=str,
|
| 555 |
+
default=None,
|
| 556 |
+
help="FSDP's auto wrap policy. (useful only when `use_fsdp` flag is passed).",
|
| 557 |
+
)
|
| 558 |
+
fsdp_args.add_argument(
|
| 559 |
+
"--fsdp_transformer_layer_cls_to_wrap",
|
| 560 |
+
default=None,
|
| 561 |
+
type=str,
|
| 562 |
+
help="Transformer layer class name (case-sensitive) to wrap ,e.g, `BertLayer`, `GPTJBlock`, `T5Block` .... "
|
| 563 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 564 |
+
)
|
| 565 |
+
fsdp_args.add_argument(
|
| 566 |
+
"--fsdp_backward_prefetch",
|
| 567 |
+
default=None,
|
| 568 |
+
type=str,
|
| 569 |
+
help="FSDP's backward prefetch policy. (useful only when `use_fsdp` flag is passed).",
|
| 570 |
+
)
|
| 571 |
+
fsdp_args.add_argument(
|
| 572 |
+
"--fsdp_state_dict_type",
|
| 573 |
+
default=None,
|
| 574 |
+
type=str,
|
| 575 |
+
help="FSDP's state dict type. (useful only when `use_fsdp` flag is passed).",
|
| 576 |
+
)
|
| 577 |
+
fsdp_args.add_argument(
|
| 578 |
+
"--fsdp_forward_prefetch",
|
| 579 |
+
default="false",
|
| 580 |
+
type=str,
|
| 581 |
+
help="If True, then FSDP explicitly prefetches the next upcoming "
|
| 582 |
+
"all-gather while executing in the forward pass (useful only when `use_fsdp` flag is passed).",
|
| 583 |
+
)
|
| 584 |
+
fsdp_args.add_argument(
|
| 585 |
+
"--fsdp_use_orig_params",
|
| 586 |
+
default="true",
|
| 587 |
+
type=str,
|
| 588 |
+
help="If True, allows non-uniform `requires_grad` during init, which means support for interspersed frozen and trainable paramteres."
|
| 589 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 590 |
+
)
|
| 591 |
+
fsdp_args.add_argument(
|
| 592 |
+
"--fsdp_cpu_ram_efficient_loading",
|
| 593 |
+
default="true",
|
| 594 |
+
type=str,
|
| 595 |
+
help="If True, only the first process loads the pretrained model checkoint while all other processes have empty weights. "
|
| 596 |
+
"Only applicable for 🤗 Transformers. When using this, `--fsdp_sync_module_states` needs to True. "
|
| 597 |
+
"(useful only when `use_fsdp` flag is passed).",
|
| 598 |
+
)
|
| 599 |
+
fsdp_args.add_argument(
|
| 600 |
+
"--fsdp_sync_module_states",
|
| 601 |
+
default="true",
|
| 602 |
+
type=str,
|
| 603 |
+
help="If True, each individually wrapped FSDP unit will broadcast module parameters from rank 0."
|
| 604 |
+
" (useful only when `use_fsdp` flag is passed).",
|
| 605 |
+
)
|
| 606 |
+
fsdp_args.add_argument(
|
| 607 |
+
"--fsdp_activation_checkpointing",
|
| 608 |
+
default="false",
|
| 609 |
+
type=str,
|
| 610 |
+
help="Decides Whether (true|false) intermediate activations are freed during the forward pass, and a checkpoint is left as a placeholder. (useful only when `use_fsdp` flag is passed).",
|
| 611 |
+
)
|
| 612 |
+
|
| 613 |
+
# megatron_lm args
|
| 614 |
+
megatron_lm_args = parser.add_argument_group("Megatron-LM Arguments", "Arguments related to Megatron-LM.")
|
| 615 |
+
megatron_lm_args.add_argument(
|
| 616 |
+
"--megatron_lm_tp_degree",
|
| 617 |
+
type=int,
|
| 618 |
+
default=1,
|
| 619 |
+
help="Megatron-LM's Tensor Parallelism (TP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 620 |
+
)
|
| 621 |
+
megatron_lm_args.add_argument(
|
| 622 |
+
"--megatron_lm_pp_degree",
|
| 623 |
+
type=int,
|
| 624 |
+
default=1,
|
| 625 |
+
help="Megatron-LM's Pipeline Parallelism (PP) degree. (useful only when `use_megatron_lm` flag is passed).",
|
| 626 |
+
)
|
| 627 |
+
megatron_lm_args.add_argument(
|
| 628 |
+
"--megatron_lm_num_micro_batches",
|
| 629 |
+
type=int,
|
| 630 |
+
default=None,
|
| 631 |
+
help="Megatron-LM's number of micro batches when PP degree > 1. (useful only when `use_megatron_lm` flag is passed).",
|
| 632 |
+
)
|
| 633 |
+
megatron_lm_args.add_argument(
|
| 634 |
+
"--megatron_lm_sequence_parallelism",
|
| 635 |
+
default=None,
|
| 636 |
+
type=str,
|
| 637 |
+
help="Decides Whether (true|false) to enable Sequence Parallelism when TP degree > 1. "
|
| 638 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 639 |
+
)
|
| 640 |
+
megatron_lm_args.add_argument(
|
| 641 |
+
"--megatron_lm_recompute_activations",
|
| 642 |
+
default=None,
|
| 643 |
+
type=str,
|
| 644 |
+
help="Decides Whether (true|false) to enable Selective Activation Recomputation. "
|
| 645 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 646 |
+
)
|
| 647 |
+
megatron_lm_args.add_argument(
|
| 648 |
+
"--megatron_lm_use_distributed_optimizer",
|
| 649 |
+
default=None,
|
| 650 |
+
type=str,
|
| 651 |
+
help="Decides Whether (true|false) to use distributed optimizer "
|
| 652 |
+
"which shards optimizer state and gradients across Data Pralellel (DP) ranks. "
|
| 653 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 654 |
+
)
|
| 655 |
+
megatron_lm_args.add_argument(
|
| 656 |
+
"--megatron_lm_gradient_clipping",
|
| 657 |
+
default=1.0,
|
| 658 |
+
type=float,
|
| 659 |
+
help="Megatron-LM's gradient clipping value based on global L2 Norm (0 to disable). "
|
| 660 |
+
"(useful only when `use_megatron_lm` flag is passed).",
|
| 661 |
+
)
|
| 662 |
+
|
| 663 |
+
# FP8 arguments
|
| 664 |
+
fp8_args = parser.add_argument_group(
|
| 665 |
+
"FP8 Arguments", "Arguments related to FP8 training (requires `--mixed_precision=fp8`)"
|
| 666 |
+
)
|
| 667 |
+
fp8_args.add_argument(
|
| 668 |
+
"--fp8_backend",
|
| 669 |
+
type=str,
|
| 670 |
+
choices=["te", "msamp"],
|
| 671 |
+
help="Choose a backend to train with FP8 (te: TransformerEngine, msamp: MS-AMP)",
|
| 672 |
+
)
|
| 673 |
+
fp8_args.add_argument(
|
| 674 |
+
"--fp8_use_autocast_during_eval",
|
| 675 |
+
default=False,
|
| 676 |
+
action="store_true",
|
| 677 |
+
help="Whether to use FP8 autocast during eval mode (useful only when `--fp8_backend=te` is passed). Generally better metrics are found when this is not passed.",
|
| 678 |
+
)
|
| 679 |
+
fp8_args.add_argument(
|
| 680 |
+
"--fp8_margin",
|
| 681 |
+
type=int,
|
| 682 |
+
default=0,
|
| 683 |
+
help="The margin to use for the gradient scaling (useful only when `--fp8_backend=te` is passed).",
|
| 684 |
+
)
|
| 685 |
+
fp8_args.add_argument(
|
| 686 |
+
"--fp8_interval",
|
| 687 |
+
type=int,
|
| 688 |
+
default=1,
|
| 689 |
+
help="The interval to use for how often the scaling factor is recomputed (useful only when `--fp8_backend=te` is passed).",
|
| 690 |
+
)
|
| 691 |
+
fp8_args.add_argument(
|
| 692 |
+
"--fp8_format",
|
| 693 |
+
type=str,
|
| 694 |
+
default="HYBRID",
|
| 695 |
+
choices=["HYBRID", "E4M3", "E5M2"],
|
| 696 |
+
help="The format to use for the FP8 recipe (useful only when `--fp8_backend=te` is passed).",
|
| 697 |
+
)
|
| 698 |
+
fp8_args.add_argument(
|
| 699 |
+
"--fp8_amax_history_len",
|
| 700 |
+
type=int,
|
| 701 |
+
default=1024,
|
| 702 |
+
help="The length of the history to use for the scaling factor computation (useful only when `--fp8_backend=te` is passed).",
|
| 703 |
+
)
|
| 704 |
+
fp8_args.add_argument(
|
| 705 |
+
"--fp8_amax_compute_algo",
|
| 706 |
+
type=str,
|
| 707 |
+
default="most_recent",
|
| 708 |
+
choices=["max", "most_recent"],
|
| 709 |
+
help="The algorithm to use for the scaling factor computation. (useful only when `--fp8_backend=te` is passed).",
|
| 710 |
+
)
|
| 711 |
+
fp8_args.add_argument(
|
| 712 |
+
"--fp8_override_linear_precision",
|
| 713 |
+
type=lambda x: tuple(map(str_to_bool, x.split(","))),
|
| 714 |
+
default=(False, False, False),
|
| 715 |
+
help="Whether or not to execute `fprop`, `dgrad`, and `wgrad` GEMMS in higher precision. Should be passed in a comma-separated string of booleans (useful only when `--fp8_backend=te` is passed).",
|
| 716 |
+
)
|
| 717 |
+
fp8_args.add_argument(
|
| 718 |
+
"--fp8_opt_level",
|
| 719 |
+
type=str,
|
| 720 |
+
default="O2",
|
| 721 |
+
choices=["O1", "O2"],
|
| 722 |
+
help="What level of 8-bit collective communication should be used with MS-AMP (useful only when `--fp8_backend=msamp` is passed).",
|
| 723 |
+
)
|
| 724 |
+
|
| 725 |
+
# AWS arguments
|
| 726 |
+
aws_args = parser.add_argument_group("AWS Arguments", "Arguments related to AWS.")
|
| 727 |
+
aws_args.add_argument(
|
| 728 |
+
"--aws_access_key_id",
|
| 729 |
+
type=str,
|
| 730 |
+
default=None,
|
| 731 |
+
help="The AWS_ACCESS_KEY_ID used to launch the Amazon SageMaker training job",
|
| 732 |
+
)
|
| 733 |
+
aws_args.add_argument(
|
| 734 |
+
"--aws_secret_access_key",
|
| 735 |
+
type=str,
|
| 736 |
+
default=None,
|
| 737 |
+
help="The AWS_SECRET_ACCESS_KEY used to launch the Amazon SageMaker training job.",
|
| 738 |
+
)
|
| 739 |
+
parser.add_argument(
|
| 740 |
+
"--debug",
|
| 741 |
+
action="store_true",
|
| 742 |
+
help="Whether to print out the torch.distributed stack trace when something fails.",
|
| 743 |
+
)
|
| 744 |
+
parser.add_argument(
|
| 745 |
+
"training_script",
|
| 746 |
+
type=str,
|
| 747 |
+
help=(
|
| 748 |
+
"The full path to the script to be launched in parallel, followed by all the arguments for the training "
|
| 749 |
+
"script."
|
| 750 |
+
),
|
| 751 |
+
)
|
| 752 |
+
|
| 753 |
+
# MPI arguments
|
| 754 |
+
mpirun_args = parser.add_argument_group("MPI Arguments", "Arguments related to mpirun for Multi-CPU")
|
| 755 |
+
mpirun_args.add_argument(
|
| 756 |
+
"--mpirun_hostfile",
|
| 757 |
+
type=str,
|
| 758 |
+
default=None,
|
| 759 |
+
help="Location for a hostfile for using Accelerate to launch a multi-CPU training job with mpirun. This will "
|
| 760 |
+
"get passed to the MPI --hostfile or -f parameter, depending on which MPI program is installed.",
|
| 761 |
+
)
|
| 762 |
+
mpirun_args.add_argument(
|
| 763 |
+
"--mpirun_ccl",
|
| 764 |
+
type=int,
|
| 765 |
+
default=1,
|
| 766 |
+
help="The number of oneCCL worker threads when using Accelerate to launch multi-CPU training with mpirun.",
|
| 767 |
+
)
|
| 768 |
+
|
| 769 |
+
# ParallelismConfig arguments
|
| 770 |
+
parallelism_config_args = parser.add_argument_group(
|
| 771 |
+
"ParallelismConfig Arguments",
|
| 772 |
+
"Arguments related to the ParallelismConfig used for distributed training.",
|
| 773 |
+
)
|
| 774 |
+
parallelism_config_args.add_argument(
|
| 775 |
+
"--parallelism_config_dp_replicate_size",
|
| 776 |
+
type=int,
|
| 777 |
+
default=1,
|
| 778 |
+
help="The number of processes for data parallel training. Defaults to 1 (no data parallelism).",
|
| 779 |
+
)
|
| 780 |
+
|
| 781 |
+
parallelism_config_args.add_argument(
|
| 782 |
+
"--parallelism_config_dp_shard_size",
|
| 783 |
+
type=int,
|
| 784 |
+
default=1,
|
| 785 |
+
help="The number of processes for FSDP sharding. Defaults to 1 (No FSDP sharding).",
|
| 786 |
+
)
|
| 787 |
+
|
| 788 |
+
parallelism_config_args.add_argument(
|
| 789 |
+
"--parallelism_config_tp_size",
|
| 790 |
+
type=int,
|
| 791 |
+
default=1,
|
| 792 |
+
help="The number of processes for tensor parallel training. Defaults to 1 (no tensor parallelism).",
|
| 793 |
+
)
|
| 794 |
+
|
| 795 |
+
parallelism_config_args.add_argument(
|
| 796 |
+
"--parallelism_config_cp_size",
|
| 797 |
+
type=int,
|
| 798 |
+
default=1,
|
| 799 |
+
help="The number of processese for context parallel training. Defaults to 1 (no context parallelism).",
|
| 800 |
+
)
|
| 801 |
+
parallelism_config_args.add_argument(
|
| 802 |
+
"--parallelism_config_cp_comm_strategy",
|
| 803 |
+
type=str,
|
| 804 |
+
default="allgather",
|
| 805 |
+
help="The communication strategy for context parallel training. Defaults to 'allgather'. Other option is alltoall",
|
| 806 |
+
)
|
| 807 |
+
|
| 808 |
+
# Other arguments of the training scripts
|
| 809 |
+
parser.add_argument("training_script_args", nargs=argparse.REMAINDER, help="Arguments of the training script.")
|
| 810 |
+
|
| 811 |
+
if subparsers is not None:
|
| 812 |
+
parser.set_defaults(func=launch_command)
|
| 813 |
+
return parser
|
| 814 |
+
|
| 815 |
+
|
| 816 |
+
def simple_launcher(args):
    """Run the training script as one local subprocess (no distributed backend).

    The command line and environment come from ``prepare_simple_launcher_cmd_env``
    (defined elsewhere in this module). A non-zero exit code is surfaced as a
    ``subprocess.CalledProcessError`` unless ``--quiet`` was passed, in which
    case the launcher exits silently with status 1.
    """
    launch_cmd, launch_env = prepare_simple_launcher_cmd_env(args)

    child = subprocess.Popen(launch_cmd, env=launch_env)
    child.wait()
    rc = child.returncode
    if rc == 0:
        return
    # Failure path: honor --quiet by exiting without a traceback.
    if args.quiet:
        sys.exit(1)
    raise subprocess.CalledProcessError(returncode=rc, cmd=launch_cmd)
|
| 826 |
+
|
| 827 |
+
|
| 828 |
+
def multi_gpu_launcher(args):
    """Launch the training script through `torch.distributed.run` (torchrun).

    Builds the launch environment, disables NCCL P2P/IB on GPUs that lack
    support for them, filters the accelerate args down to the set torchrun
    understands, and runs the script in-process via `distrib_run.run`.
    """
    import torch.distributed.run as distrib_run

    current_env = prepare_multi_gpu_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        warn = False
        # Only force-disable (and warn) when the user has not already set these
        # NCCL knobs themselves.
        if "NCCL_P2P_DISABLE" not in current_env:
            current_env["NCCL_P2P_DISABLE"] = "1"
            warn = True
        if "NCCL_IB_DISABLE" not in current_env:
            current_env["NCCL_IB_DISABLE"] = "1"
            warn = True
        if warn:
            logger.warning(message)

    # Capture --debug before `args` is replaced by the filtered torchrun namespace.
    debug = getattr(args, "debug", False)
    args = _filter_args(
        args,
        distrib_run.get_args_parser(),
        ["--training_script", args.training_script, "--training_script_args", args.training_script_args],
    )

    with patch_environment(**current_env):
        try:
            distrib_run.run(args)
        except Exception:
            # With --debug and rich installed, render a pretty stack trace;
            # otherwise propagate the original exception unchanged.
            if is_rich_available() and debug:
                console = get_console()
                console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
                console.print_exception(suppress=[__file__], show_locals=False)
            else:
                raise
|
| 861 |
+
|
| 862 |
+
|
| 863 |
+
def deepspeed_launcher(args):
    """Launch the training script with DeepSpeed.

    Multi-node runs (with a non-"standard" multinode launcher) are delegated to
    the DeepSpeed runner as a subprocess, after persisting the launch
    environment to DeepSpeed's environment file. Single-node runs (or the
    "standard" launcher) go through `torch.distributed.run` in-process, like
    `multi_gpu_launcher`.

    Raises:
        ImportError: if DeepSpeed is not installed.
        subprocess.CalledProcessError: if the subprocess path fails and
            ``--quiet`` was not passed.
    """
    import torch.distributed.run as distrib_run

    if not is_deepspeed_available():
        raise ImportError("DeepSpeed is not installed => run `pip3 install deepspeed` or build it from source.")
    else:
        from deepspeed.launcher.runner import DEEPSPEED_ENVIRONMENT_NAME

    cmd, current_env = prepare_deepspeed_cmd_env(args)
    if not check_cuda_p2p_ib_support():
        message = "Using RTX 4000 series which doesn't support faster communication speedups. Ensuring P2P and IB communications are disabled."
        warn = False
        # Only force-disable (and warn) when the user has not set these NCCL
        # knobs themselves.
        if "NCCL_P2P_DISABLE" not in current_env:
            current_env["NCCL_P2P_DISABLE"] = "1"
            warn = True
        if "NCCL_IB_DISABLE" not in current_env:
            current_env["NCCL_IB_DISABLE"] = "1"
            warn = True
        if warn:
            logger.warning(message)

    # DEEPSPEED_MULTINODE_LAUNCHERS[1] is the "standard" launcher, which does
    # not need the env file / subprocess path.
    if args.num_machines > 1 and args.deepspeed_multinode_launcher != DEEPSPEED_MULTINODE_LAUNCHERS[1]:
        # Append the launch env to DeepSpeed's environment file so worker
        # nodes see the same variables.
        with open(DEEPSPEED_ENVIRONMENT_NAME, "a") as f:
            valid_env_items = convert_dict_to_env_variables(current_env)
            if len(valid_env_items) > 1:
                f.writelines(valid_env_items)

        process = subprocess.Popen(cmd, env=current_env)
        process.wait()
        if process.returncode != 0:
            if not args.quiet:
                raise subprocess.CalledProcessError(returncode=process.returncode, cmd=cmd)
            else:
                sys.exit(1)
    else:
        # Capture --debug before `args` is replaced by the filtered torchrun
        # namespace.
        debug = getattr(args, "debug", False)
        args = _filter_args(
            args,
            distrib_run.get_args_parser(),
            ["--training_script", args.training_script, "--training_script_args", args.training_script_args],
        )
        with patch_environment(**current_env):
            try:
                distrib_run.run(args)
            except Exception:
                # With --debug and rich installed, render a pretty stack
                # trace; otherwise propagate the exception unchanged.
                if is_rich_available() and debug:
                    console = get_console()
                    console.print("\n[bold red]Using --debug, `torch.distributed` Stack Trace:[/bold red]")
                    console.print_exception(suppress=[__file__], show_locals=False)
                else:
                    raise
|
| 914 |
+
|
| 915 |
+
|
| 916 |
+
def tpu_launcher(args):
    """Launch the training function on TPU via `torch_xla` multiprocessing.

    Imports the training script as a Python module, looks up the function
    named by ``--main_training_function``, and spawns it through
    ``xmp.spawn`` inside the prepared TPU environment.

    Raises:
        ValueError: if ``--no_python`` was passed, or the module does not
            define the requested training function.
    """
    import torch_xla.distributed.xla_multiprocessing as xmp

    if args.no_python:
        raise ValueError("--no_python cannot be used with TPU launcher")

    args, current_env = prepare_tpu(args, {})

    if args.module:
        # Script already given as a module path.
        mod_name = args.training_script
    else:
        # Treat the script file as an importable module: put its directory
        # on sys.path and import by stem.
        script_path = Path(args.training_script)
        sys.path.append(str(script_path.parent.resolve()))
        mod_name = script_path.stem

    mod = importlib.import_module(mod_name)
    target = args.main_training_function
    if not hasattr(mod, target):
        raise ValueError(
            f"Your training script should have a function named {args.main_training_function}, or you should pass a "
            "different value to `--main_training_function`."
        )

    # Make the script's own argv visible to the training function.
    sys.argv = [mod.__file__] + args.training_script_args

    entry_point = getattr(mod, target)
    with patch_environment(**current_env):
        xmp.spawn(PrepareForLaunch(entry_point), args=())
|
| 945 |
+
|
| 946 |
+
|
| 947 |
+
def tpu_pod_launcher(args):
    """Launch training across a TPU pod using `torch_xla.distributed.xla_dist`.

    Builds a nested ``accelerate-launch`` command (with ``--no_tpu_cluster``)
    that each pod worker runs for its own slice, then delegates fan-out to
    ``xla_dist.resolve_and_execute``.

    Raises:
        ValueError: if any ``docker_*`` flag is set (containers unsupported here).
    """
    from torch_xla.distributed import xla_dist

    current_env = {}
    args, current_env = prepare_tpu(args, current_env, True)
    debug = getattr(args, "debug", False)

    # Keep the script/args before filtering down to the xla_dist arg set.
    training_script = args.training_script
    training_script_args = args.training_script_args
    new_args = _filter_args(
        args, xla_dist.get_args_parser(), ["--tpu", args.tpu_name, "--positional", "", "--restart-tpuvm-pod-server"]
    )

    if args.tpu_use_sudo:
        new_cmd = ["sudo"]
    else:
        new_cmd = []

    # Per-worker command: re-enter accelerate-launch in single-machine TPU
    # mode so each pod host launches its own processes.
    new_cmd += [
        "accelerate-launch",
        "--tpu",
        "--no_tpu_cluster",
        "--num_machines",
        "1",
        "--mixed_precision",
        "no",
        "--dynamo_backend",
        "no",
        "--num_processes",
        str(args.num_processes),
        "--main_training_function",
        str(args.main_training_function),
        training_script,
    ] + training_script_args

    new_args.positional = new_cmd
    # Reject any docker_* flags: container execution is not supported by this
    # launcher, so fail loudly listing the offending flags.
    bad_flags = ""
    for arg in vars(new_args):
        if arg.startswith("docker_"):
            value = getattr(new_args, arg)
            if value != "" and value is not None:
                bad_flags += f'{arg}="{value}"\n'
    if bad_flags != "":
        raise ValueError(
            f"Docker containers are not supported for TPU pod launcher currently, please remove the following flags:\n{bad_flags}"
        )
    # Forward the prepared environment to the workers, plus a marker that the
    # nested launch is running inside a TPU pod.
    new_args.env = [f"{k}={v}" for k, v in current_env.items()]
    new_args.env.append("ACCELERATE_IN_TPU_POD=1")
    try:
        xla_dist.resolve_and_execute(new_args)
    except Exception:
        # With --debug and rich installed, render a pretty stack trace;
        # otherwise propagate the exception unchanged.
        if is_rich_available() and debug:
            console = get_console()
            console.print("\n[bold red]Using --debug, `torch_xla.xla_dist` Stack Trace:[/bold red]")
            console.print_exception(suppress=[__file__], show_locals=False)
        else:
            raise
|
| 1004 |
+
|
| 1005 |
+
|
| 1006 |
+
def sagemaker_launcher(sagemaker_config: SageMakerConfig, args):
    """Launch the training job on Amazon SageMaker via a HuggingFace estimator.

    Validates that sagemaker is installed and that a plain Python script was
    given, converts the accelerate config/args into estimator kwargs and
    inputs, then starts the fit and reports where the model data lands.

    Raises:
        ImportError: if the sagemaker package is missing.
        ValueError: if ``--module`` or ``--no_python`` was passed.
    """
    if not is_sagemaker_available():
        raise ImportError(
            "Please install sagemaker to be able to launch training on Amazon SageMaker with `pip install accelerate[sagemaker]`"
        )
    if args.module or args.no_python:
        raise ValueError(
            "SageMaker requires a python training script file and cannot be used with --module or --no_python"
        )

    from sagemaker.huggingface import HuggingFace

    # NOTE: `args` is rebound here to the estimator kwargs dict built by the
    # (upstream-named) helper.
    args, sagemaker_inputs = prepare_sagemager_args_inputs(sagemaker_config, args)

    estimator = HuggingFace(**args)
    estimator.fit(inputs=sagemaker_inputs)
    print(f"You can find your model data at: {estimator.model_data}")
|
| 1024 |
+
|
| 1025 |
+
|
| 1026 |
+
def _validate_launch_command(args):
    """Validate and fill in launch arguments, merging in config-file defaults.

    Returns:
        tuple: ``(args, defaults, mp_from_config_flag)`` where ``defaults`` is
        the loaded config (or ``None``) and ``mp_from_config_flag`` records
        that ``mixed_precision`` came from the config file rather than the CLI.

    Raises:
        ValueError: on mutually-exclusive launcher flags, bad process counts,
            invalid GPU id lists, or unsupported mixed-precision settings.
    """
    # Sanity checks: at most one launcher mode may be selected explicitly.
    if sum([args.multi_gpu, args.cpu, args.tpu, args.use_deepspeed, args.use_fsdp]) > 1:
        raise ValueError(
            "You can only use one of `--cpu`, `--multi_gpu`, `--tpu`, `--use_deepspeed`, `--use_fsdp` at a time."
        )
    if args.multi_gpu and (args.num_processes is not None) and (args.num_processes < 2):
        raise ValueError("You need to use at least 2 processes to use `--multi_gpu`.")

    if (not args.use_fsdp or args.fsdp_version == 1) and args.use_parallelism_config:
        raise ValueError("You cannot use `--use_parallelism_config` without `--use_fsdp` and `--fsdp_version=2`. ")

    defaults = None
    warned = []  # human-readable notes about defaulted values, warned once at the end
    mp_from_config_flag = False
    # Get the default from the config file.
    # NOTE(review): `and` binds tighter than `or`, so this reads as
    # `config_file is not None OR (isfile(default) AND not cpu)` — an explicit
    # --config_file is honored even with --cpu. Presumably intentional; confirm.
    if args.config_file is not None or os.path.isfile(default_config_file) and not args.cpu:
        defaults = load_config_from_file(args.config_file)
        # Only infer the launcher mode from the config when none was given on
        # the command line.
        if (
            not args.multi_gpu
            and not args.tpu
            and not args.tpu_use_cluster
            and not args.use_deepspeed
            and not args.use_fsdp
            and not args.use_megatron_lm
        ):
            args.use_deepspeed = defaults.distributed_type == DistributedType.DEEPSPEED
            args.multi_gpu = (
                True
                if defaults.distributed_type
                in (
                    DistributedType.MULTI_GPU,
                    DistributedType.MULTI_NPU,
                    DistributedType.MULTI_MLU,
                    DistributedType.MULTI_SDAA,
                    DistributedType.MULTI_MUSA,
                    DistributedType.MULTI_XPU,
                    DistributedType.MULTI_HPU,
                )
                else False
            )
            args.tpu = defaults.distributed_type == DistributedType.XLA
            args.use_fsdp = defaults.distributed_type == DistributedType.FSDP
            args.use_megatron_lm = defaults.distributed_type == DistributedType.MEGATRON_LM
            args.tpu_use_cluster = defaults.tpu_use_cluster if args.tpu else False
            args.use_parallelism_config = defaults.parallelism_config != {}
        if args.gpu_ids is None:
            if defaults.gpu_ids is not None:
                args.gpu_ids = defaults.gpu_ids
            else:
                args.gpu_ids = "all"

        if args.multi_gpu and args.num_machines is None:
            args.num_machines = defaults.num_machines

        if len(args.gpu_ids.split(",")) < 2 and (args.gpu_ids != "all") and args.multi_gpu and args.num_machines <= 1:
            raise ValueError(
                "Less than two GPU ids were configured and tried to run on on multiple GPUs. "
                "Please ensure at least two are specified for `--gpu_ids`, or use `--gpu_ids='all'`."
            )
        if defaults.compute_environment == ComputeEnvironment.LOCAL_MACHINE:
            # Update args with the defaults
            for name, attr in defaults.__dict__.items():
                if isinstance(attr, dict):
                    # Copy defaults.somedict.somearg to args.somearg and
                    # defaults.fsdp_config.x to args.fsdp_x
                    for key, value in attr.items():
                        if name == "fsdp_config" and not key.startswith("fsdp"):
                            key = "fsdp_" + key
                        elif name == "fp8_config" and not key.startswith("fp8"):
                            key = "fp8_" + key
                        # Only override args the user left at their defaults.
                        if hasattr(args, "nondefault") and key not in args.nondefault:
                            setattr(args, key, value)
                elif (
                    name not in ["compute_environment", "mixed_precision", "distributed_type"]
                    and getattr(args, name, None) is None
                ):
                    # Those args are handled separately
                    setattr(args, name, attr)
        if not args.debug:
            args.debug = defaults.debug

        if not args.mixed_precision:
            if defaults.mixed_precision is None:
                args.mixed_precision = "no"
            else:
                args.mixed_precision = defaults.mixed_precision
                # Remember the value came from config, not the CLI (used by
                # the DeepSpeed path in `launch_command`).
                mp_from_config_flag = True
        else:
            native_amp = is_bf16_available(True)
            if (
                args.mixed_precision == "bf16"
                and not native_amp
                and not (args.tpu and is_torch_xla_available(check_is_tpu=True))
            ):
                raise ValueError("bf16 mixed precision requires PyTorch >= 1.10 and a supported device.")

        # Silently set the default here
        if args.dynamo_backend is None:
            args.dynamo_backend = "no"
        if args.num_processes == -1:
            raise ValueError("You need to manually pass in `--num_processes` using this config yaml.")
    else:
        # No config file: derive defaults from the hardware that is visible.
        if args.num_processes is None:
            if is_xpu_available():
                args.num_processes = torch.xpu.device_count()
            elif is_mlu_available():
                args.num_processes = torch.mlu.device_count()
            elif is_sdaa_available():
                args.num_processes = torch.sdaa.device_count()
            elif is_musa_available():
                args.num_processes = torch.musa.device_count()
            elif is_npu_available():
                args.num_processes = torch.npu.device_count()
            elif is_hpu_available():
                args.num_processes = torch.hpu.device_count()
            else:
                args.num_processes = torch.cuda.device_count()
            warned.append(f"\t`--num_processes` was set to a value of `{args.num_processes}`")
        if args.debug is None:
            args.debug = False
        # Auto-enable multi-GPU when several accelerators of any supported
        # flavor are present and more than one process was requested.
        if (
            not args.multi_gpu
            and args.num_processes > 1
            and (
                (is_xpu_available() and torch.xpu.device_count() > 1)
                or (is_npu_available() and torch.npu.device_count() > 1)
                or (is_hpu_available() and torch.hpu.device_count() > 1)
                or (is_mlu_available() and torch.mlu.device_count() > 1)
                or (is_sdaa_available() and torch.sdaa.device_count() > 1)
                or (is_musa_available() and torch.musa.device_count() > 1)
                or (torch.cuda.is_available() and torch.cuda.device_count() > 1)
            )
        ):
            warned.append(
                "\t\tMore than one GPU was found, enabling multi-GPU training.\n"
                "\t\tIf this was unintended please pass in `--num_processes=1`."
            )
            args.multi_gpu = True
        if args.num_machines is None:
            warned.append("\t`--num_machines` was set to a value of `1`")
            args.num_machines = 1
        if args.mixed_precision is None:
            warned.append("\t`--mixed_precision` was set to a value of `'no'`")
            args.mixed_precision = "no"
        if not hasattr(args, "use_cpu"):
            args.use_cpu = args.cpu
        if args.dynamo_backend is None:
            warned.append("\t`--dynamo_backend` was set to a value of `'no'`")
            args.dynamo_backend = "no"
    if args.debug:
        logger.debug("Running script in debug mode, expect distributed operations to be slightly slower.")

    is_aws_env_disabled = defaults is None or (
        defaults is not None and defaults.compute_environment != ComputeEnvironment.AMAZON_SAGEMAKER
    )
    if is_aws_env_disabled and args.num_cpu_threads_per_process is None:
        args.num_cpu_threads_per_process = get_int_from_env(["OMP_NUM_THREADS"], 1)
        # On pure-CPU runs with OMP_NUM_THREADS unset, spread the physical
        # cores across the local ranks for better out-of-box throughput.
        if args.use_cpu and args.num_processes >= 1 and get_int_from_env(["OMP_NUM_THREADS"], 0) == 0:
            local_size = get_int_from_env(
                ["MPI_LOCALNRANKS", "OMPI_COMM_WORLD_LOCAL_SIZE", "MV2_COMM_WORLD_LOCAL_SIZE"],
                max(int(args.num_processes / args.num_machines), 1),
            )
            threads_per_process = int(psutil.cpu_count(logical=False) / local_size)
            if threads_per_process > 1:
                args.num_cpu_threads_per_process = threads_per_process
                warned.append(
                    f"\t`--num_cpu_threads_per_process` was set to `{args.num_cpu_threads_per_process}` to improve out-of-box performance when training on CPUs"
                )

    if args.use_xpu is not None:
        logger.warning(
            "use_xpu is deprecated and ignored, will be removed in Accelerate v1.20. "
            "XPU is a PyTorch native citizen now, we don't need extra argument to enable it any more."
        )

    if any(warned):
        message = "The following values were not passed to `accelerate launch` and had defaults used instead:\n"
        message += "\n".join(warned)
        message += (
            "\nTo avoid this warning pass in values for each of the problematic parameters or run `accelerate config`."
        )
        logger.warning(message)
    return args, defaults, mp_from_config_flag
|
| 1210 |
+
|
| 1211 |
+
|
| 1212 |
+
def launch_command(args):
    """Validate the arguments and dispatch to the appropriate launcher.

    Order of precedence (all but SageMaker require ``--cpu`` to be unset):
    DeepSpeed, then FSDP/Megatron-LM/multi-GPU (all via torchrun), then TPU
    (pod or single host), then SageMaker when the config says so, and finally
    the plain single-process launcher.
    """
    args, defaults, mp_from_config_flag = _validate_launch_command(args)
    on_accelerator = not args.cpu

    # Use the proper launcher
    if args.use_deepspeed and on_accelerator:
        # Tell the DeepSpeed path which fields came from the accelerate
        # config (comma-separated string), including mixed_precision when it
        # was sourced from the config file.
        config_fields = list(defaults.deepspeed_config.keys()) if defaults else []
        if mp_from_config_flag:
            config_fields.append("mixed_precision")
        args.deepspeed_fields_from_accelerate_config = ",".join(config_fields)
        deepspeed_launcher(args)
    elif (args.use_fsdp or args.use_megatron_lm or args.multi_gpu) and on_accelerator:
        # FSDP, Megatron-LM and plain multi-GPU all share the torchrun path.
        multi_gpu_launcher(args)
    elif args.tpu and on_accelerator:
        tpu_pod_launcher(args) if args.tpu_use_cluster else tpu_launcher(args)
    elif defaults is not None and defaults.compute_environment == ComputeEnvironment.AMAZON_SAGEMAKER:
        sagemaker_launcher(defaults, args)
    else:
        simple_launcher(args)
|
| 1236 |
+
|
| 1237 |
+
|
| 1238 |
+
def main():
    """CLI entry point: parse the launch arguments and run the launcher."""
    launch_command(launch_command_parser().parse_args())
|
| 1242 |
+
|
| 1243 |
+
|
| 1244 |
+
# Allow running this module directly (e.g. `python launch.py ...`) in
# addition to the installed `accelerate launch` entry point.
if __name__ == "__main__":
    main()
|
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/INSTALLER
ADDED
|
@@ -0,0 +1 @@
|
|
|
|
|
|
|
| 1 |
+
pip
|
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/METADATA
ADDED
|
@@ -0,0 +1,441 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Metadata-Version: 2.1
|
| 2 |
+
Name: colorama
|
| 3 |
+
Version: 0.4.6
|
| 4 |
+
Summary: Cross-platform colored terminal text.
|
| 5 |
+
Project-URL: Homepage, https://github.com/tartley/colorama
|
| 6 |
+
Author-email: Jonathan Hartley <tartley@tartley.com>
|
| 7 |
+
License-File: LICENSE.txt
|
| 8 |
+
Keywords: ansi,color,colour,crossplatform,terminal,text,windows,xplatform
|
| 9 |
+
Classifier: Development Status :: 5 - Production/Stable
|
| 10 |
+
Classifier: Environment :: Console
|
| 11 |
+
Classifier: Intended Audience :: Developers
|
| 12 |
+
Classifier: License :: OSI Approved :: BSD License
|
| 13 |
+
Classifier: Operating System :: OS Independent
|
| 14 |
+
Classifier: Programming Language :: Python
|
| 15 |
+
Classifier: Programming Language :: Python :: 2
|
| 16 |
+
Classifier: Programming Language :: Python :: 2.7
|
| 17 |
+
Classifier: Programming Language :: Python :: 3
|
| 18 |
+
Classifier: Programming Language :: Python :: 3.7
|
| 19 |
+
Classifier: Programming Language :: Python :: 3.8
|
| 20 |
+
Classifier: Programming Language :: Python :: 3.9
|
| 21 |
+
Classifier: Programming Language :: Python :: 3.10
|
| 22 |
+
Classifier: Programming Language :: Python :: Implementation :: CPython
|
| 23 |
+
Classifier: Programming Language :: Python :: Implementation :: PyPy
|
| 24 |
+
Classifier: Topic :: Terminals
|
| 25 |
+
Requires-Python: !=3.0.*,!=3.1.*,!=3.2.*,!=3.3.*,!=3.4.*,!=3.5.*,!=3.6.*,>=2.7
|
| 26 |
+
Description-Content-Type: text/x-rst
|
| 27 |
+
|
| 28 |
+
.. image:: https://img.shields.io/pypi/v/colorama.svg
|
| 29 |
+
:target: https://pypi.org/project/colorama/
|
| 30 |
+
:alt: Latest Version
|
| 31 |
+
|
| 32 |
+
.. image:: https://img.shields.io/pypi/pyversions/colorama.svg
|
| 33 |
+
:target: https://pypi.org/project/colorama/
|
| 34 |
+
:alt: Supported Python versions
|
| 35 |
+
|
| 36 |
+
.. image:: https://github.com/tartley/colorama/actions/workflows/test.yml/badge.svg
|
| 37 |
+
:target: https://github.com/tartley/colorama/actions/workflows/test.yml
|
| 38 |
+
:alt: Build Status
|
| 39 |
+
|
| 40 |
+
Colorama
|
| 41 |
+
========
|
| 42 |
+
|
| 43 |
+
Makes ANSI escape character sequences (for producing colored terminal text and
|
| 44 |
+
cursor positioning) work under MS Windows.
|
| 45 |
+
|
| 46 |
+
.. |donate| image:: https://www.paypalobjects.com/en_US/i/btn/btn_donate_SM.gif
|
| 47 |
+
:target: https://www.paypal.com/cgi-bin/webscr?cmd=_donations&business=2MZ9D2GMLYCUJ&item_name=Colorama&currency_code=USD
|
| 48 |
+
:alt: Donate with Paypal
|
| 49 |
+
|
| 50 |
+
`PyPI for releases <https://pypi.org/project/colorama/>`_ |
|
| 51 |
+
`Github for source <https://github.com/tartley/colorama>`_ |
|
| 52 |
+
`Colorama for enterprise on Tidelift <https://github.com/tartley/colorama/blob/master/ENTERPRISE.md>`_
|
| 53 |
+
|
| 54 |
+
If you find Colorama useful, please |donate| to the authors. Thank you!
|
| 55 |
+
|
| 56 |
+
Installation
|
| 57 |
+
------------
|
| 58 |
+
|
| 59 |
+
Tested on CPython 2.7, 3.7, 3.8, 3.9 and 3.10 and Pypy 2.7 and 3.8.
|
| 60 |
+
|
| 61 |
+
No requirements other than the standard library.
|
| 62 |
+
|
| 63 |
+
.. code-block:: bash
|
| 64 |
+
|
| 65 |
+
pip install colorama
|
| 66 |
+
# or
|
| 67 |
+
conda install -c anaconda colorama
|
| 68 |
+
|
| 69 |
+
Description
|
| 70 |
+
-----------
|
| 71 |
+
|
| 72 |
+
ANSI escape character sequences have long been used to produce colored terminal
|
| 73 |
+
text and cursor positioning on Unix and Macs. Colorama makes this work on
|
| 74 |
+
Windows, too, by wrapping ``stdout``, stripping ANSI sequences it finds (which
|
| 75 |
+
would appear as gobbledygook in the output), and converting them into the
|
| 76 |
+
appropriate win32 calls to modify the state of the terminal. On other platforms,
|
| 77 |
+
Colorama does nothing.
|
| 78 |
+
|
| 79 |
+
This has the upshot of providing a simple cross-platform API for printing
|
| 80 |
+
colored terminal text from Python, and has the happy side-effect that existing
|
| 81 |
+
applications or libraries which use ANSI sequences to produce colored output on
|
| 82 |
+
Linux or Macs can now also work on Windows, simply by calling
|
| 83 |
+
``colorama.just_fix_windows_console()`` (since v0.4.6) or ``colorama.init()``
|
| 84 |
+
(all versions, but may have other side-effects – see below).
|
| 85 |
+
|
| 86 |
+
An alternative approach is to install ``ansi.sys`` on Windows machines, which
|
| 87 |
+
provides the same behaviour for all applications running in terminals. Colorama
|
| 88 |
+
is intended for situations where that isn't easy (e.g., maybe your app doesn't
|
| 89 |
+
have an installer.)
|
| 90 |
+
|
| 91 |
+
Demo scripts in the source code repository print some colored text using
|
| 92 |
+
ANSI sequences. Compare their output under Gnome-terminal's built in ANSI
|
| 93 |
+
handling, versus on Windows Command-Prompt using Colorama:
|
| 94 |
+
|
| 95 |
+
.. image:: https://github.com/tartley/colorama/raw/master/screenshots/ubuntu-demo.png
|
| 96 |
+
:width: 661
|
| 97 |
+
:height: 357
|
| 98 |
+
:alt: ANSI sequences on Ubuntu under gnome-terminal.
|
| 99 |
+
|
| 100 |
+
.. image:: https://github.com/tartley/colorama/raw/master/screenshots/windows-demo.png
|
| 101 |
+
:width: 668
|
| 102 |
+
:height: 325
|
| 103 |
+
:alt: Same ANSI sequences on Windows, using Colorama.
|
| 104 |
+
|
| 105 |
+
These screenshots show that, on Windows, Colorama does not support ANSI 'dim
|
| 106 |
+
text'; it looks the same as 'normal text'.
|
| 107 |
+
|
| 108 |
+
Usage
|
| 109 |
+
-----
|
| 110 |
+
|
| 111 |
+
Initialisation
|
| 112 |
+
..............
|
| 113 |
+
|
| 114 |
+
If the only thing you want from Colorama is to get ANSI escapes to work on
|
| 115 |
+
Windows, then run:
|
| 116 |
+
|
| 117 |
+
.. code-block:: python
|
| 118 |
+
|
| 119 |
+
from colorama import just_fix_windows_console
|
| 120 |
+
just_fix_windows_console()
|
| 121 |
+
|
| 122 |
+
If you're on a recent version of Windows 10 or better, and your stdout/stderr
|
| 123 |
+
are pointing to a Windows console, then this will flip the magic configuration
|
| 124 |
+
switch to enable Windows' built-in ANSI support.
|
| 125 |
+
|
| 126 |
+
If you're on an older version of Windows, and your stdout/stderr are pointing to
|
| 127 |
+
a Windows console, then this will wrap ``sys.stdout`` and/or ``sys.stderr`` in a
|
| 128 |
+
magic file object that intercepts ANSI escape sequences and issues the
|
| 129 |
+
appropriate Win32 calls to emulate them.
|
| 130 |
+
|
| 131 |
+
In all other circumstances, it does nothing whatsoever. Basically the idea is
|
| 132 |
+
that this makes Windows act like Unix with respect to ANSI escape handling.
|
| 133 |
+
|
| 134 |
+
It's safe to call this function multiple times. It's safe to call this function
|
| 135 |
+
on non-Windows platforms, but it won't do anything. It's safe to call this
|
| 136 |
+
function when one or both of your stdout/stderr are redirected to a file – it
|
| 137 |
+
won't do anything to those streams.
|
| 138 |
+
|
| 139 |
+
Alternatively, you can use the older interface with more features (but also more
|
| 140 |
+
potential footguns):
|
| 141 |
+
|
| 142 |
+
.. code-block:: python
|
| 143 |
+
|
| 144 |
+
from colorama import init
|
| 145 |
+
init()
|
| 146 |
+
|
| 147 |
+
This does the same thing as ``just_fix_windows_console``, except for the
|
| 148 |
+
following differences:
|
| 149 |
+
|
| 150 |
+
- It's not safe to call ``init`` multiple times; you can end up with multiple
|
| 151 |
+
layers of wrapping and broken ANSI support.
|
| 152 |
+
|
| 153 |
+
- Colorama will apply a heuristic to guess whether stdout/stderr support ANSI,
|
| 154 |
+
and if it thinks they don't, then it will wrap ``sys.stdout`` and
|
| 155 |
+
``sys.stderr`` in a magic file object that strips out ANSI escape sequences
|
| 156 |
+
before printing them. This happens on all platforms, and can be convenient if
|
| 157 |
+
you want to write your code to emit ANSI escape sequences unconditionally, and
|
| 158 |
+
let Colorama decide whether they should actually be output. But note that
|
| 159 |
+
Colorama's heuristic is not particularly clever.
|
| 160 |
+
|
| 161 |
+
- ``init`` also accepts explicit keyword args to enable/disable various
|
| 162 |
+
functionality – see below.
|
| 163 |
+
|
| 164 |
+
To stop using Colorama before your program exits, simply call ``deinit()``.
|
| 165 |
+
This will restore ``stdout`` and ``stderr`` to their original values, so that
|
| 166 |
+
Colorama is disabled. To resume using Colorama again, call ``reinit()``; it is
|
| 167 |
+
cheaper than calling ``init()`` again (but does the same thing).
|
| 168 |
+
|
| 169 |
+
Most users should depend on ``colorama >= 0.4.6``, and use
|
| 170 |
+
``just_fix_windows_console``. The old ``init`` interface will be supported
|
| 171 |
+
indefinitely for backwards compatibility, but we don't plan to fix any issues
|
| 172 |
+
with it, also for backwards compatibility.
|
| 173 |
+
|
| 174 |
+
Colored Output
|
| 175 |
+
..............
|
| 176 |
+
|
| 177 |
+
Cross-platform printing of colored text can then be done using Colorama's
|
| 178 |
+
constant shorthand for ANSI escape sequences. These are deliberately
|
| 179 |
+
rudimentary, see below.
|
| 180 |
+
|
| 181 |
+
.. code-block:: python
|
| 182 |
+
|
| 183 |
+
from colorama import Fore, Back, Style
|
| 184 |
+
print(Fore.RED + 'some red text')
|
| 185 |
+
print(Back.GREEN + 'and with a green background')
|
| 186 |
+
print(Style.DIM + 'and in dim text')
|
| 187 |
+
print(Style.RESET_ALL)
|
| 188 |
+
print('back to normal now')
|
| 189 |
+
|
| 190 |
+
...or simply by manually printing ANSI sequences from your own code:
|
| 191 |
+
|
| 192 |
+
.. code-block:: python
|
| 193 |
+
|
| 194 |
+
print('\033[31m' + 'some red text')
|
| 195 |
+
print('\033[39m') # and reset to default color
|
| 196 |
+
|
| 197 |
+
...or, Colorama can be used in conjunction with existing ANSI libraries
|
| 198 |
+
such as the venerable `Termcolor <https://pypi.org/project/termcolor/>`_
|
| 199 |
+
the fabulous `Blessings <https://pypi.org/project/blessings/>`_,
|
| 200 |
+
or the incredible `Rich <https://pypi.org/project/rich/>`_.
|
| 201 |
+
|
| 202 |
+
If you wish Colorama's Fore, Back and Style constants were more capable,
|
| 203 |
+
then consider using one of the above highly capable libraries to generate
|
| 204 |
+
colors, etc, and use Colorama just for its primary purpose: to convert
|
| 205 |
+
those ANSI sequences to also work on Windows:
|
| 206 |
+
|
| 207 |
+
SIMILARLY, do not send PRs adding the generation of new ANSI types to Colorama.
|
| 208 |
+
We are only interested in converting ANSI codes to win32 API calls, not
|
| 209 |
+
shortcuts like the above to generate ANSI characters.
|
| 210 |
+
|
| 211 |
+
.. code-block:: python
|
| 212 |
+
|
| 213 |
+
from colorama import just_fix_windows_console
|
| 214 |
+
from termcolor import colored
|
| 215 |
+
|
| 216 |
+
# use Colorama to make Termcolor work on Windows too
|
| 217 |
+
just_fix_windows_console()
|
| 218 |
+
|
| 219 |
+
# then use Termcolor for all colored text output
|
| 220 |
+
print(colored('Hello, World!', 'green', 'on_red'))
|
| 221 |
+
|
| 222 |
+
Available formatting constants are::
|
| 223 |
+
|
| 224 |
+
Fore: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
|
| 225 |
+
Back: BLACK, RED, GREEN, YELLOW, BLUE, MAGENTA, CYAN, WHITE, RESET.
|
| 226 |
+
Style: DIM, NORMAL, BRIGHT, RESET_ALL
|
| 227 |
+
|
| 228 |
+
``Style.RESET_ALL`` resets foreground, background, and brightness. Colorama will
|
| 229 |
+
perform this reset automatically on program exit.
|
| 230 |
+
|
| 231 |
+
These are fairly well supported, but not part of the standard::
|
| 232 |
+
|
| 233 |
+
Fore: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX
|
| 234 |
+
Back: LIGHTBLACK_EX, LIGHTRED_EX, LIGHTGREEN_EX, LIGHTYELLOW_EX, LIGHTBLUE_EX, LIGHTMAGENTA_EX, LIGHTCYAN_EX, LIGHTWHITE_EX
|
| 235 |
+
|
| 236 |
+
Cursor Positioning
|
| 237 |
+
..................
|
| 238 |
+
|
| 239 |
+
ANSI codes to reposition the cursor are supported. See ``demos/demo06.py`` for
|
| 240 |
+
an example of how to generate them.
|
| 241 |
+
|
| 242 |
+
Init Keyword Args
|
| 243 |
+
.................
|
| 244 |
+
|
| 245 |
+
``init()`` accepts some ``**kwargs`` to override default behaviour.
|
| 246 |
+
|
| 247 |
+
init(autoreset=False):
|
| 248 |
+
If you find yourself repeatedly sending reset sequences to turn off color
|
| 249 |
+
changes at the end of every print, then ``init(autoreset=True)`` will
|
| 250 |
+
automate that:
|
| 251 |
+
|
| 252 |
+
.. code-block:: python
|
| 253 |
+
|
| 254 |
+
from colorama import init
|
| 255 |
+
init(autoreset=True)
|
| 256 |
+
print(Fore.RED + 'some red text')
|
| 257 |
+
print('automatically back to default color again')
|
| 258 |
+
|
| 259 |
+
init(strip=None):
|
| 260 |
+
Pass ``True`` or ``False`` to override whether ANSI codes should be
|
| 261 |
+
stripped from the output. The default behaviour is to strip if on Windows
|
| 262 |
+
or if output is redirected (not a tty).
|
| 263 |
+
|
| 264 |
+
init(convert=None):
|
| 265 |
+
Pass ``True`` or ``False`` to override whether to convert ANSI codes in the
|
| 266 |
+
output into win32 calls. The default behaviour is to convert if on Windows
|
| 267 |
+
and output is to a tty (terminal).
|
| 268 |
+
|
| 269 |
+
init(wrap=True):
|
| 270 |
+
On Windows, Colorama works by replacing ``sys.stdout`` and ``sys.stderr``
|
| 271 |
+
with proxy objects, which override the ``.write()`` method to do their work.
|
| 272 |
+
If this wrapping causes you problems, then this can be disabled by passing
|
| 273 |
+
``init(wrap=False)``. The default behaviour is to wrap if ``autoreset`` or
|
| 274 |
+
``strip`` or ``convert`` are True.
|
| 275 |
+
|
| 276 |
+
When wrapping is disabled, colored printing on non-Windows platforms will
|
| 277 |
+
continue to work as normal. To do cross-platform colored output, you can
|
| 278 |
+
use Colorama's ``AnsiToWin32`` proxy directly:
|
| 279 |
+
|
| 280 |
+
.. code-block:: python
|
| 281 |
+
|
| 282 |
+
import sys
|
| 283 |
+
from colorama import init, AnsiToWin32
|
| 284 |
+
init(wrap=False)
|
| 285 |
+
stream = AnsiToWin32(sys.stderr).stream
|
| 286 |
+
|
| 287 |
+
# Python 2
|
| 288 |
+
print >>stream, Fore.BLUE + 'blue text on stderr'
|
| 289 |
+
|
| 290 |
+
# Python 3
|
| 291 |
+
print(Fore.BLUE + 'blue text on stderr', file=stream)
|
| 292 |
+
|
| 293 |
+
Recognised ANSI Sequences
|
| 294 |
+
.........................
|
| 295 |
+
|
| 296 |
+
ANSI sequences generally take the form::
|
| 297 |
+
|
| 298 |
+
ESC [ <param> ; <param> ... <command>
|
| 299 |
+
|
| 300 |
+
Where ``<param>`` is an integer, and ``<command>`` is a single letter. Zero or
|
| 301 |
+
more params are passed to a ``<command>``. If no params are passed, it is
|
| 302 |
+
generally synonymous with passing a single zero. No spaces exist in the
|
| 303 |
+
sequence; they have been inserted here simply to read more easily.
|
| 304 |
+
|
| 305 |
+
The only ANSI sequences that Colorama converts into win32 calls are::
|
| 306 |
+
|
| 307 |
+
ESC [ 0 m # reset all (colors and brightness)
|
| 308 |
+
ESC [ 1 m # bright
|
| 309 |
+
ESC [ 2 m # dim (looks same as normal brightness)
|
| 310 |
+
ESC [ 22 m # normal brightness
|
| 311 |
+
|
| 312 |
+
# FOREGROUND:
|
| 313 |
+
ESC [ 30 m # black
|
| 314 |
+
ESC [ 31 m # red
|
| 315 |
+
ESC [ 32 m # green
|
| 316 |
+
ESC [ 33 m # yellow
|
| 317 |
+
ESC [ 34 m # blue
|
| 318 |
+
ESC [ 35 m # magenta
|
| 319 |
+
ESC [ 36 m # cyan
|
| 320 |
+
ESC [ 37 m # white
|
| 321 |
+
ESC [ 39 m # reset
|
| 322 |
+
|
| 323 |
+
# BACKGROUND
|
| 324 |
+
ESC [ 40 m # black
|
| 325 |
+
ESC [ 41 m # red
|
| 326 |
+
ESC [ 42 m # green
|
| 327 |
+
ESC [ 43 m # yellow
|
| 328 |
+
ESC [ 44 m # blue
|
| 329 |
+
ESC [ 45 m # magenta
|
| 330 |
+
ESC [ 46 m # cyan
|
| 331 |
+
ESC [ 47 m # white
|
| 332 |
+
ESC [ 49 m # reset
|
| 333 |
+
|
| 334 |
+
# cursor positioning
|
| 335 |
+
ESC [ y;x H # position cursor at x across, y down
|
| 336 |
+
ESC [ y;x f # position cursor at x across, y down
|
| 337 |
+
ESC [ n A # move cursor n lines up
|
| 338 |
+
ESC [ n B # move cursor n lines down
|
| 339 |
+
ESC [ n C # move cursor n characters forward
|
| 340 |
+
ESC [ n D # move cursor n characters backward
|
| 341 |
+
|
| 342 |
+
# clear the screen
|
| 343 |
+
ESC [ mode J # clear the screen
|
| 344 |
+
|
| 345 |
+
# clear the line
|
| 346 |
+
ESC [ mode K # clear the line
|
| 347 |
+
|
| 348 |
+
Multiple numeric params to the ``'m'`` command can be combined into a single
|
| 349 |
+
sequence::
|
| 350 |
+
|
| 351 |
+
ESC [ 36 ; 45 ; 1 m # bright cyan text on magenta background
|
| 352 |
+
|
| 353 |
+
All other ANSI sequences of the form ``ESC [ <param> ; <param> ... <command>``
|
| 354 |
+
are silently stripped from the output on Windows.
|
| 355 |
+
|
| 356 |
+
Any other form of ANSI sequence, such as single-character codes or alternative
|
| 357 |
+
initial characters, are not recognised or stripped. It would be cool to add
|
| 358 |
+
them though. Let me know if it would be useful for you, via the Issues on
|
| 359 |
+
GitHub.
|
| 360 |
+
|
| 361 |
+
Status & Known Problems
|
| 362 |
+
-----------------------
|
| 363 |
+
|
| 364 |
+
I've personally only tested it on Windows XP (CMD, Console2), Ubuntu
|
| 365 |
+
(gnome-terminal, xterm), and OS X.
|
| 366 |
+
|
| 367 |
+
Some valid ANSI sequences aren't recognised.
|
| 368 |
+
|
| 369 |
+
If you're hacking on the code, see `README-hacking.md`_. ESPECIALLY, see the
|
| 370 |
+
explanation there of why we do not want PRs that allow Colorama to generate new
|
| 371 |
+
types of ANSI codes.
|
| 372 |
+
|
| 373 |
+
See outstanding issues and wish-list:
|
| 374 |
+
https://github.com/tartley/colorama/issues
|
| 375 |
+
|
| 376 |
+
If anything doesn't work for you, or doesn't do what you expected or hoped for,
|
| 377 |
+
I'd love to hear about it on that issues list, would be delighted by patches,
|
| 378 |
+
and would be happy to grant commit access to anyone who submits a working patch
|
| 379 |
+
or two.
|
| 380 |
+
|
| 381 |
+
.. _README-hacking.md: README-hacking.md
|
| 382 |
+
|
| 383 |
+
License
|
| 384 |
+
-------
|
| 385 |
+
|
| 386 |
+
Copyright Jonathan Hartley & Arnon Yaari, 2013-2020. BSD 3-Clause license; see
|
| 387 |
+
LICENSE file.
|
| 388 |
+
|
| 389 |
+
Professional support
|
| 390 |
+
--------------------
|
| 391 |
+
|
| 392 |
+
.. |tideliftlogo| image:: https://cdn2.hubspot.net/hubfs/4008838/website/logos/logos_for_download/Tidelift_primary-shorthand-logo.png
|
| 393 |
+
:alt: Tidelift
|
| 394 |
+
:target: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme
|
| 395 |
+
|
| 396 |
+
.. list-table::
|
| 397 |
+
:widths: 10 100
|
| 398 |
+
|
| 399 |
+
* - |tideliftlogo|
|
| 400 |
+
- Professional support for colorama is available as part of the
|
| 401 |
+
`Tidelift Subscription`_.
|
| 402 |
+
Tidelift gives software development teams a single source for purchasing
|
| 403 |
+
and maintaining their software, with professional grade assurances from
|
| 404 |
+
the experts who know it best, while seamlessly integrating with existing
|
| 405 |
+
tools.
|
| 406 |
+
|
| 407 |
+
.. _Tidelift Subscription: https://tidelift.com/subscription/pkg/pypi-colorama?utm_source=pypi-colorama&utm_medium=referral&utm_campaign=readme
|
| 408 |
+
|
| 409 |
+
Thanks
|
| 410 |
+
------
|
| 411 |
+
|
| 412 |
+
See the CHANGELOG for more thanks!
|
| 413 |
+
|
| 414 |
+
* Marc Schlaich (schlamar) for a ``setup.py`` fix for Python2.5.
|
| 415 |
+
* Marc Abramowitz, reported & fixed a crash on exit with closed ``stdout``,
|
| 416 |
+
providing a solution to issue #7's setuptools/distutils debate,
|
| 417 |
+
and other fixes.
|
| 418 |
+
* User 'eryksun', for guidance on correctly instantiating ``ctypes.windll``.
|
| 419 |
+
* Matthew McCormick for politely pointing out a longstanding crash on non-Win.
|
| 420 |
+
* Ben Hoyt, for a magnificent fix under 64-bit Windows.
|
| 421 |
+
* Jesse at Empty Square for submitting a fix for examples in the README.
|
| 422 |
+
* User 'jamessp', an observant documentation fix for cursor positioning.
|
| 423 |
+
* User 'vaal1239', Dave Mckee & Lackner Kristof for a tiny but much-needed Win7
|
| 424 |
+
fix.
|
| 425 |
+
* Julien Stuyck, for wisely suggesting Python3 compatible updates to README.
|
| 426 |
+
* Daniel Griffith for multiple fabulous patches.
|
| 427 |
+
* Oscar Lesta for a valuable fix to stop ANSI chars being sent to non-tty
|
| 428 |
+
output.
|
| 429 |
+
* Roger Binns, for many suggestions, valuable feedback, & bug reports.
|
| 430 |
+
* Tim Golden for thought and much appreciated feedback on the initial idea.
|
| 431 |
+
* User 'Zearin' for updates to the README file.
|
| 432 |
+
* John Szakmeister for adding support for light colors
|
| 433 |
+
* Charles Merriam for adding documentation to demos
|
| 434 |
+
* Jurko for a fix on 64-bit Windows CPython2.5 w/o ctypes
|
| 435 |
+
* Florian Bruhin for a fix when stdout or stderr are None
|
| 436 |
+
* Thomas Weininger for fixing ValueError on Windows
|
| 437 |
+
* Remi Rampin for better Github integration and fixes to the README file
|
| 438 |
+
* Simeon Visser for closing a file handle using 'with' and updating classifiers
|
| 439 |
+
to include Python 3.3 and 3.4
|
| 440 |
+
* Andy Neff for fixing RESET of LIGHT_EX colors.
|
| 441 |
+
* Jonathan Hartley for the initial idea and implementation.
|
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/RECORD
ADDED
|
@@ -0,0 +1,31 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
colorama-0.4.6.dist-info/INSTALLER,sha256=zuuue4knoyJ-UwPPXg8fezS7VCrXJQrAP7zeNuwvFQg,4
|
| 2 |
+
colorama-0.4.6.dist-info/METADATA,sha256=e67SnrUMOym9sz_4TjF3vxvAV4T3aF7NyqRHHH3YEMw,17158
|
| 3 |
+
colorama-0.4.6.dist-info/RECORD,,
|
| 4 |
+
colorama-0.4.6.dist-info/WHEEL,sha256=cdcF4Fbd0FPtw2EMIOwH-3rSOTUdTCeOSXRMD1iLUb8,105
|
| 5 |
+
colorama-0.4.6.dist-info/licenses/LICENSE.txt,sha256=ysNcAmhuXQSlpxQL-zs25zrtSWZW6JEQLkKIhteTAxg,1491
|
| 6 |
+
colorama/__init__.py,sha256=wePQA4U20tKgYARySLEC047ucNX-g8pRLpYBuiHlLb8,266
|
| 7 |
+
colorama/__pycache__/__init__.cpython-310.pyc,,
|
| 8 |
+
colorama/__pycache__/ansi.cpython-310.pyc,,
|
| 9 |
+
colorama/__pycache__/ansitowin32.cpython-310.pyc,,
|
| 10 |
+
colorama/__pycache__/initialise.cpython-310.pyc,,
|
| 11 |
+
colorama/__pycache__/win32.cpython-310.pyc,,
|
| 12 |
+
colorama/__pycache__/winterm.cpython-310.pyc,,
|
| 13 |
+
colorama/ansi.py,sha256=Top4EeEuaQdBWdteKMEcGOTeKeF19Q-Wo_6_Cj5kOzQ,2522
|
| 14 |
+
colorama/ansitowin32.py,sha256=vPNYa3OZbxjbuFyaVo0Tmhmy1FZ1lKMWCnT7odXpItk,11128
|
| 15 |
+
colorama/initialise.py,sha256=-hIny86ClXo39ixh5iSCfUIa2f_h_bgKRDW7gqs-KLU,3325
|
| 16 |
+
colorama/tests/__init__.py,sha256=MkgPAEzGQd-Rq0w0PZXSX2LadRWhUECcisJY8lSrm4Q,75
|
| 17 |
+
colorama/tests/__pycache__/__init__.cpython-310.pyc,,
|
| 18 |
+
colorama/tests/__pycache__/ansi_test.cpython-310.pyc,,
|
| 19 |
+
colorama/tests/__pycache__/ansitowin32_test.cpython-310.pyc,,
|
| 20 |
+
colorama/tests/__pycache__/initialise_test.cpython-310.pyc,,
|
| 21 |
+
colorama/tests/__pycache__/isatty_test.cpython-310.pyc,,
|
| 22 |
+
colorama/tests/__pycache__/utils.cpython-310.pyc,,
|
| 23 |
+
colorama/tests/__pycache__/winterm_test.cpython-310.pyc,,
|
| 24 |
+
colorama/tests/ansi_test.py,sha256=FeViDrUINIZcr505PAxvU4AjXz1asEiALs9GXMhwRaE,2839
|
| 25 |
+
colorama/tests/ansitowin32_test.py,sha256=RN7AIhMJ5EqDsYaCjVo-o4u8JzDD4ukJbmevWKS70rY,10678
|
| 26 |
+
colorama/tests/initialise_test.py,sha256=BbPy-XfyHwJ6zKozuQOvNvQZzsx9vdb_0bYXn7hsBTc,6741
|
| 27 |
+
colorama/tests/isatty_test.py,sha256=Pg26LRpv0yQDB5Ac-sxgVXG7hsA1NYvapFgApZfYzZg,1866
|
| 28 |
+
colorama/tests/utils.py,sha256=1IIRylG39z5-dzq09R_ngufxyPZxgldNbrxKxUGwGKE,1079
|
| 29 |
+
colorama/tests/winterm_test.py,sha256=qoWFPEjym5gm2RuMwpf3pOis3a5r_PJZFCzK254JL8A,3709
|
| 30 |
+
colorama/win32.py,sha256=YQOKwMTwtGBbsY4dL5HYTvwTeP9wIQra5MvPNddpxZs,6181
|
| 31 |
+
colorama/winterm.py,sha256=XCQFDHjPi6AHYNdZwy0tA02H-Jh48Jp-HvCjeLeLp3U,7134
|
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/WHEEL
ADDED
|
@@ -0,0 +1,5 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Wheel-Version: 1.0
|
| 2 |
+
Generator: hatchling 1.11.1
|
| 3 |
+
Root-Is-Purelib: true
|
| 4 |
+
Tag: py2-none-any
|
| 5 |
+
Tag: py3-none-any
|
pythonProject/.venv/Lib/site-packages/colorama-0.4.6.dist-info/licenses/LICENSE.txt
ADDED
|
@@ -0,0 +1,27 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
Copyright (c) 2010 Jonathan Hartley
|
| 2 |
+
All rights reserved.
|
| 3 |
+
|
| 4 |
+
Redistribution and use in source and binary forms, with or without
|
| 5 |
+
modification, are permitted provided that the following conditions are met:
|
| 6 |
+
|
| 7 |
+
* Redistributions of source code must retain the above copyright notice, this
|
| 8 |
+
list of conditions and the following disclaimer.
|
| 9 |
+
|
| 10 |
+
* Redistributions in binary form must reproduce the above copyright notice,
|
| 11 |
+
this list of conditions and the following disclaimer in the documentation
|
| 12 |
+
and/or other materials provided with the distribution.
|
| 13 |
+
|
| 14 |
+
* Neither the name of the copyright holders, nor those of its contributors
|
| 15 |
+
may be used to endorse or promote products derived from this software without
|
| 16 |
+
specific prior written permission.
|
| 17 |
+
|
| 18 |
+
THIS SOFTWARE IS PROVIDED BY THE COPYRIGHT HOLDERS AND CONTRIBUTORS "AS IS" AND
|
| 19 |
+
ANY EXPRESS OR IMPLIED WARRANTIES, INCLUDING, BUT NOT LIMITED TO, THE IMPLIED
|
| 20 |
+
WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR PURPOSE ARE
|
| 21 |
+
DISCLAIMED. IN NO EVENT SHALL THE COPYRIGHT HOLDER OR CONTRIBUTORS BE LIABLE
|
| 22 |
+
FOR ANY DIRECT, INDIRECT, INCIDENTAL, SPECIAL, EXEMPLARY, OR CONSEQUENTIAL
|
| 23 |
+
DAMAGES (INCLUDING, BUT NOT LIMITED TO, PROCUREMENT OF SUBSTITUTE GOODS OR
|
| 24 |
+
SERVICES; LOSS OF USE, DATA, OR PROFITS; OR BUSINESS INTERRUPTION) HOWEVER
|
| 25 |
+
CAUSED AND ON ANY THEORY OF LIABILITY, WHETHER IN CONTRACT, STRICT LIABILITY,
|
| 26 |
+
OR TORT (INCLUDING NEGLIGENCE OR OTHERWISE) ARISING IN ANY WAY OUT OF THE USE
|
| 27 |
+
OF THIS SOFTWARE, EVEN IF ADVISED OF THE POSSIBILITY OF SUCH DAMAGE.
|
pythonProject/.venv/Lib/site-packages/colorama/__pycache__/__init__.cpython-310.pyc
ADDED
|
Binary file (451 Bytes). View file
|
|
|
pythonProject/.venv/Lib/site-packages/colorama/ansi.py
ADDED
|
@@ -0,0 +1,102 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
'''
This module generates ANSI escape codes for printing colors to terminals.
See: http://en.wikipedia.org/wiki/ANSI_escape_code
'''

# Control Sequence Introducer: prefix of color/cursor escape sequences.
CSI = '\033['
# Operating System Command: prefix of terminal-control sequences (e.g. title).
OSC = '\033]'
# Bell character: terminates an OSC sequence.
BEL = '\a'
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def code_to_chars(code):
    """Wrap a numeric SGR *code* in a full ``CSI ... m`` escape sequence."""
    return '{0}{1}m'.format(CSI, code)
|
| 14 |
+
|
| 15 |
+
def set_title(title):
    """Return the OSC escape sequence that sets the terminal window title."""
    return '{0}2;{1}{2}'.format(OSC, title, BEL)
|
| 17 |
+
|
| 18 |
+
def clear_screen(mode=2):
    """Return the ``CSI ... J`` escape; *mode* 2 (default) clears the whole screen."""
    return '{0}{1}J'.format(CSI, mode)
|
| 20 |
+
|
| 21 |
+
def clear_line(mode=2):
    """Return the ``CSI ... K`` escape; *mode* 2 (default) clears the whole line."""
    return '{0}{1}K'.format(CSI, mode)
|
| 23 |
+
|
| 24 |
+
|
| 25 |
+
class AnsiCodes(object):
    """Base class that turns numeric class attributes into escape strings.

    Subclasses declare plain integer SGR codes as class attributes; on
    instantiation, each public attribute is shadowed by an instance
    attribute holding the complete ANSI escape sequence for that code.
    """

    def __init__(self):
        for attr_name in dir(self):
            if attr_name.startswith('_'):
                continue  # skip private/dunder names; only codes are public
            numeric_code = getattr(self, attr_name)
            setattr(self, attr_name, code_to_chars(numeric_code))
|
| 34 |
+
|
| 35 |
+
|
| 36 |
+
class AnsiCursor(object):
    """Builders for ANSI cursor-movement escape sequences."""

    def UP(self, n=1):
        """Return the escape that moves the cursor *n* lines up."""
        return '{0}{1}A'.format(CSI, n)

    def DOWN(self, n=1):
        """Return the escape that moves the cursor *n* lines down."""
        return '{0}{1}B'.format(CSI, n)

    def FORWARD(self, n=1):
        """Return the escape that moves the cursor *n* columns forward."""
        return '{0}{1}C'.format(CSI, n)

    def BACK(self, n=1):
        """Return the escape that moves the cursor *n* columns backward."""
        return '{0}{1}D'.format(CSI, n)

    def POS(self, x=1, y=1):
        """Return the escape that places the cursor at column *x*, row *y* (1-based)."""
        return '{0}{1};{2}H'.format(CSI, y, x)
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
class AnsiFore(AnsiCodes):
    # Standard ANSI foreground (text) color codes (SGR 30-37);
    # RESET (39) restores the terminal's default foreground color.
    BLACK = 30
    RED = 31
    GREEN = 32
    YELLOW = 33
    BLUE = 34
    MAGENTA = 35
    CYAN = 36
    WHITE = 37
    RESET = 39

    # These are fairly well supported, but not part of the standard.
    # High-intensity ("bright") foreground colors, SGR 90-97.
    LIGHTBLACK_EX = 90
    LIGHTRED_EX = 91
    LIGHTGREEN_EX = 92
    LIGHTYELLOW_EX = 93
    LIGHTBLUE_EX = 94
    LIGHTMAGENTA_EX = 95
    LIGHTCYAN_EX = 96
    LIGHTWHITE_EX = 97
| 70 |
+
|
| 71 |
+
class AnsiBack(AnsiCodes):
    # Standard ANSI background color codes (SGR 40-47);
    # RESET (49) restores the terminal's default background color.
    BLACK = 40
    RED = 41
    GREEN = 42
    YELLOW = 43
    BLUE = 44
    MAGENTA = 45
    CYAN = 46
    WHITE = 47
    RESET = 49

    # These are fairly well supported, but not part of the standard.
    # High-intensity ("bright") background colors, SGR 100-107.
    LIGHTBLACK_EX = 100
    LIGHTRED_EX = 101
    LIGHTGREEN_EX = 102
    LIGHTYELLOW_EX = 103
    LIGHTBLUE_EX = 104
    LIGHTMAGENTA_EX = 105
    LIGHTCYAN_EX = 106
    LIGHTWHITE_EX = 107
| 92 |
+
|
| 93 |
+
class AnsiStyle(AnsiCodes):
    # SGR intensity codes: BRIGHT (1) and DIM (2) change intensity,
    # NORMAL (22) restores default intensity, RESET_ALL (0) clears all
    # attributes (color and style).
    BRIGHT = 1
    DIM = 2
    NORMAL = 22
    RESET_ALL = 0
|
| 99 |
+
# Module-level singletons: client code uses e.g. Fore.RED, Back.GREEN,
# Style.BRIGHT and Cursor.UP() to obtain ready-made escape strings.
Fore = AnsiFore()
Back = AnsiBack()
Style = AnsiStyle()
Cursor = AnsiCursor()
pythonProject/.venv/Lib/site-packages/colorama/ansitowin32.py
ADDED
|
@@ -0,0 +1,277 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
import re
import sys
import os

from .ansi import AnsiFore, AnsiBack, AnsiStyle, Style, BEL
from .winterm import enable_vt_processing, WinTerm, WinColor, WinStyle
from .win32 import windll, winapi_test


# winterm stays None on platforms where win32.windll is None (i.e. anywhere
# ctypes.WinDLL is unavailable); conversion code checks this before use.
winterm = None
if windll is not None:
    winterm = WinTerm()
+
|
| 15 |
+
|
| 16 |
+
class StreamWrapper(object):
    '''
    Wraps a stream (such as stdout), acting as a transparent proxy for all
    attribute access apart from method 'write()', which is delegated to our
    Converter instance.
    '''

    def __init__(self, wrapped, converter):
        # double-underscore everything to prevent clashes with names of
        # attributes on the wrapped stream object.
        self.__wrapped = wrapped
        self.__convertor = converter

    def __getattr__(self, name):
        # Anything we don't define ourselves falls through to the wrapped stream.
        return getattr(self.__wrapped, name)

    def __enter__(self, *args, **kwargs):
        # special method lookup bypasses __getattr__/__getattribute__, see
        # https://stackoverflow.com/questions/12632894/why-doesnt-getattr-work-with-exit
        # thus, contextlib magic methods are not proxied via __getattr__
        return self.__wrapped.__enter__(*args, **kwargs)

    def __exit__(self, *args, **kwargs):
        return self.__wrapped.__exit__(*args, **kwargs)

    def __setstate__(self, state):
        # Plain-dict pickling keeps the wrapper picklable alongside its stream.
        self.__dict__ = state

    def __getstate__(self):
        return self.__dict__

    def write(self, text):
        # All writes go through the converter, which decides how to emit them.
        self.__convertor.write(text)

    def isatty(self):
        stream = self.__wrapped
        # PyCharm's hosted console understands ANSI even though the real
        # isatty() reports False, so treat the true stdout/stderr as ttys there.
        if 'PYCHARM_HOSTED' in os.environ:
            if stream is not None and (stream is sys.__stdout__ or stream is sys.__stderr__):
                return True
        stream_isatty = getattr(stream, 'isatty', None)
        if stream_isatty is None:
            # Stream has no isatty() at all.
            return False
        return stream_isatty()

    @property
    def closed(self):
        try:
            return self.__wrapped.closed
        # AttributeError in the case that the stream doesn't support being closed
        # ValueError for the case that the stream has already been detached when atexit runs
        except (AttributeError, ValueError):
            return True
| 71 |
+
|
| 72 |
+
class AnsiToWin32(object):
    '''
    Implements a 'write()' method which, on Windows, will strip ANSI character
    sequences from the text, and if outputting to a tty, will convert them into
    win32 function calls.
    '''
    # Optional \001/\002 are GNU readline's non-printing-sequence markers.
    ANSI_CSI_RE = re.compile('\001?\033\\[((?:\\d|;)*)([a-zA-Z])\002?')   # Control Sequence Introducer
    ANSI_OSC_RE = re.compile('\001?\033\\]([^\a]*)(\a)\002?')   # Operating System Command

    def __init__(self, wrapped, convert=None, strip=None, autoreset=False):
        # The wrapped stream (normally sys.stdout or sys.stderr)
        self.wrapped = wrapped

        # should we reset colors to defaults after every .write()
        self.autoreset = autoreset

        # create the proxy wrapping our output stream
        self.stream = StreamWrapper(wrapped, self)

        on_windows = os.name == 'nt'
        # We test if the WinAPI works, because even if we are on Windows
        # we may be using a terminal that doesn't support the WinAPI
        # (e.g. Cygwin Terminal). In this case it's up to the terminal
        # to support the ANSI codes.
        conversion_supported = on_windows and winapi_test()
        try:
            fd = wrapped.fileno()
        except Exception:
            # Stream has no underlying file descriptor (e.g. StringIO).
            fd = -1
        # NOTE: enable_vt_processing has the side-effect of switching the
        # console to native ANSI mode when it succeeds.
        system_has_native_ansi = not on_windows or enable_vt_processing(fd)
        have_tty = not self.stream.closed and self.stream.isatty()
        need_conversion = conversion_supported and not system_has_native_ansi

        # should we strip ANSI sequences from our output?
        if strip is None:
            strip = need_conversion or not have_tty
        self.strip = strip

        # should we should convert ANSI sequences into win32 calls?
        if convert is None:
            convert = need_conversion and have_tty
        self.convert = convert

        # dict of ansi codes to win32 functions and parameters
        self.win32_calls = self.get_win32_calls()

        # are we wrapping stderr?
        self.on_stderr = self.wrapped is sys.stderr

    def should_wrap(self):
        '''
        True if this class is actually needed. If false, then the output
        stream will not be affected, nor will win32 calls be issued, so
        wrapping stdout is not actually required. This will generally be
        False on non-Windows platforms, unless optional functionality like
        autoreset has been requested using kwargs to init()
        '''
        return self.convert or self.strip or self.autoreset

    def get_win32_calls(self):
        """Return the mapping from SGR code to (win32 function, *args)."""
        if self.convert and winterm:
            return {
                AnsiStyle.RESET_ALL: (winterm.reset_all, ),
                AnsiStyle.BRIGHT: (winterm.style, WinStyle.BRIGHT),
                AnsiStyle.DIM: (winterm.style, WinStyle.NORMAL),
                AnsiStyle.NORMAL: (winterm.style, WinStyle.NORMAL),
                AnsiFore.BLACK: (winterm.fore, WinColor.BLACK),
                AnsiFore.RED: (winterm.fore, WinColor.RED),
                AnsiFore.GREEN: (winterm.fore, WinColor.GREEN),
                AnsiFore.YELLOW: (winterm.fore, WinColor.YELLOW),
                AnsiFore.BLUE: (winterm.fore, WinColor.BLUE),
                AnsiFore.MAGENTA: (winterm.fore, WinColor.MAGENTA),
                AnsiFore.CYAN: (winterm.fore, WinColor.CYAN),
                AnsiFore.WHITE: (winterm.fore, WinColor.GREY),
                AnsiFore.RESET: (winterm.fore, ),
                AnsiFore.LIGHTBLACK_EX: (winterm.fore, WinColor.BLACK, True),
                AnsiFore.LIGHTRED_EX: (winterm.fore, WinColor.RED, True),
                AnsiFore.LIGHTGREEN_EX: (winterm.fore, WinColor.GREEN, True),
                AnsiFore.LIGHTYELLOW_EX: (winterm.fore, WinColor.YELLOW, True),
                AnsiFore.LIGHTBLUE_EX: (winterm.fore, WinColor.BLUE, True),
                AnsiFore.LIGHTMAGENTA_EX: (winterm.fore, WinColor.MAGENTA, True),
                AnsiFore.LIGHTCYAN_EX: (winterm.fore, WinColor.CYAN, True),
                AnsiFore.LIGHTWHITE_EX: (winterm.fore, WinColor.GREY, True),
                AnsiBack.BLACK: (winterm.back, WinColor.BLACK),
                AnsiBack.RED: (winterm.back, WinColor.RED),
                AnsiBack.GREEN: (winterm.back, WinColor.GREEN),
                AnsiBack.YELLOW: (winterm.back, WinColor.YELLOW),
                AnsiBack.BLUE: (winterm.back, WinColor.BLUE),
                AnsiBack.MAGENTA: (winterm.back, WinColor.MAGENTA),
                AnsiBack.CYAN: (winterm.back, WinColor.CYAN),
                AnsiBack.WHITE: (winterm.back, WinColor.GREY),
                AnsiBack.RESET: (winterm.back, ),
                AnsiBack.LIGHTBLACK_EX: (winterm.back, WinColor.BLACK, True),
                AnsiBack.LIGHTRED_EX: (winterm.back, WinColor.RED, True),
                AnsiBack.LIGHTGREEN_EX: (winterm.back, WinColor.GREEN, True),
                AnsiBack.LIGHTYELLOW_EX: (winterm.back, WinColor.YELLOW, True),
                AnsiBack.LIGHTBLUE_EX: (winterm.back, WinColor.BLUE, True),
                AnsiBack.LIGHTMAGENTA_EX: (winterm.back, WinColor.MAGENTA, True),
                AnsiBack.LIGHTCYAN_EX: (winterm.back, WinColor.CYAN, True),
                AnsiBack.LIGHTWHITE_EX: (winterm.back, WinColor.GREY, True),
            }
        return dict()

    def write(self, text):
        """Write *text*, stripping/converting ANSI sequences as configured."""
        if self.strip or self.convert:
            self.write_and_convert(text)
        else:
            self.wrapped.write(text)
            self.wrapped.flush()
        if self.autoreset:
            self.reset_all()

    def reset_all(self):
        """Reset colors/styles, via win32 calls or a raw RESET_ALL sequence."""
        if self.convert:
            self.call_win32('m', (0,))
        elif not self.strip and not self.stream.closed:
            self.wrapped.write(Style.RESET_ALL)

    def write_and_convert(self, text):
        '''
        Write the given text to our wrapped stream, stripping any ANSI
        sequences from the text, and optionally converting them into win32
        calls.
        '''
        cursor = 0
        text = self.convert_osc(text)
        for match in self.ANSI_CSI_RE.finditer(text):
            start, end = match.span()
            self.write_plain_text(text, cursor, start)
            self.convert_ansi(*match.groups())
            cursor = end
        self.write_plain_text(text, cursor, len(text))

    def write_plain_text(self, text, start, end):
        """Write text[start:end] (no escape sequences) to the wrapped stream."""
        if start < end:
            self.wrapped.write(text[start:end])
            self.wrapped.flush()

    def convert_ansi(self, paramstring, command):
        """Translate one CSI sequence into the equivalent win32 call(s)."""
        if self.convert:
            params = self.extract_params(command, paramstring)
            self.call_win32(command, params)

    def extract_params(self, command, paramstring):
        """Parse the numeric parameters of a CSI sequence, applying defaults."""
        if command in 'Hf':
            # Cursor-position commands: missing params default to 1, and at
            # least two params (row, column) are always supplied.
            params = tuple(int(p) if len(p) != 0 else 1 for p in paramstring.split(';'))
            while len(params) < 2:
                # defaults:
                params = params + (1,)
        else:
            params = tuple(int(p) for p in paramstring.split(';') if len(p) != 0)
            if len(params) == 0:
                # defaults:
                if command in 'JKm':
                    params = (0,)
                elif command in 'ABCD':
                    params = (1,)

        return params

    def call_win32(self, command, params):
        """Dispatch a parsed CSI command to the appropriate winterm method."""
        if command == 'm':
            for param in params:
                if param in self.win32_calls:
                    func_args = self.win32_calls[param]
                    func = func_args[0]
                    args = func_args[1:]
                    kwargs = dict(on_stderr=self.on_stderr)
                    func(*args, **kwargs)
        elif command in 'J':
            winterm.erase_screen(params[0], on_stderr=self.on_stderr)
        elif command in 'K':
            winterm.erase_line(params[0], on_stderr=self.on_stderr)
        elif command in 'Hf':     # cursor position - absolute
            winterm.set_cursor_position(params, on_stderr=self.on_stderr)
        elif command in 'ABCD':   # cursor position - relative
            n = params[0]
            # A - up, B - down, C - forward, D - back
            x, y = {'A': (0, -n), 'B': (0, n), 'C': (n, 0), 'D': (-n, 0)}[command]
            winterm.cursor_adjust(x, y, on_stderr=self.on_stderr)

    def convert_osc(self, text):
        """Remove OSC sequences from *text*, honoring title-change requests."""
        for match in self.ANSI_OSC_RE.finditer(text):
            start, end = match.span()
            text = text[:start] + text[end:]
            paramstring, command = match.groups()
            if command == BEL:
                if paramstring.count(";") == 1:
                    params = paramstring.split(";")
                    # 0 - change title and icon (we will only change title)
                    # 1 - change icon (we don't support this)
                    # 2 - change title
                    if params[0] in '02':
                        winterm.set_title(params[1])
        return text

    def flush(self):
        """Flush the wrapped stream."""
        self.wrapped.flush()
pythonProject/.venv/Lib/site-packages/colorama/initialise.py
ADDED
|
@@ -0,0 +1,121 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.
|
| 2 |
+
import atexit
|
| 3 |
+
import contextlib
|
| 4 |
+
import sys
|
| 5 |
+
|
| 6 |
+
from .ansitowin32 import AnsiToWin32
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def _wipe_internal_state_for_tests():
    """Reset every module-level global to its pristine state (test helper)."""
    global orig_stdout, orig_stderr
    global wrapped_stdout, wrapped_stderr
    global atexit_done
    global fixed_windows_console

    orig_stdout = None
    orig_stderr = None
    wrapped_stdout = None
    wrapped_stderr = None
    atexit_done = False
    fixed_windows_console = False

    try:
        # no-op if it wasn't registered
        atexit.unregister(reset_all)
    except AttributeError:
        # python 2: no atexit.unregister. Oh well, we did our best.
        pass
+
|
| 31 |
+
|
| 32 |
+
def reset_all():
    """Emit a full color/style reset on the original stdout (atexit hook)."""
    if AnsiToWin32 is not None:    # Issue #74: objects might become None at exit
        AnsiToWin32(orig_stdout).reset_all()
+
|
| 36 |
+
|
| 37 |
+
def init(autoreset=False, convert=None, strip=None, wrap=True):
    """Install colorama's wrappers on sys.stdout and sys.stderr.

    autoreset: reset colors to defaults after every write.
    convert / strip: force (True) or forbid (False) conversion/stripping of
        ANSI sequences; None means autodetect per stream.
    wrap: if False, install no wrappers at all (the other args must be falsy).

    Registers reset_all() with atexit (once) so colors are restored on exit.
    """
    if not wrap and any([autoreset, convert, strip]):
        raise ValueError('wrap=False conflicts with any other arg=True')

    global wrapped_stdout, wrapped_stderr
    global orig_stdout, orig_stderr

    # Remember the current streams so deinit() can restore them.
    orig_stdout = sys.stdout
    orig_stderr = sys.stderr

    # sys.stdout/stderr can be None (e.g. pythonw.exe); skip wrapping then.
    if sys.stdout is None:
        wrapped_stdout = None
    else:
        sys.stdout = wrapped_stdout = \
            wrap_stream(orig_stdout, convert, strip, autoreset, wrap)
    if sys.stderr is None:
        wrapped_stderr = None
    else:
        sys.stderr = wrapped_stderr = \
            wrap_stream(orig_stderr, convert, strip, autoreset, wrap)

    global atexit_done
    if not atexit_done:
        atexit.register(reset_all)
        atexit_done = True
+
|
| 64 |
+
|
| 65 |
+
def deinit():
    """Restore sys.stdout / sys.stderr to the streams captured by init()."""
    for stream_name, original in (('stdout', orig_stdout), ('stderr', orig_stderr)):
        if original is not None:
            setattr(sys, stream_name, original)
|
| 71 |
+
|
| 72 |
+
def just_fix_windows_console():
    """Make ANSI sequences work on Windows consoles, without the global
    stream-wrapping behavior of init().

    No-op on non-Windows platforms, if already applied, or if init() has
    already installed its wrappers.
    """
    global fixed_windows_console

    if sys.platform != "win32":
        return
    if fixed_windows_console:
        return
    if wrapped_stdout is not None or wrapped_stderr is not None:
        # Someone already ran init() and it did stuff, so we won't second-guess them
        return

    # On newer versions of Windows, AnsiToWin32.__init__ will implicitly enable the
    # native ANSI support in the console as a side-effect. We only need to actually
    # replace sys.stdout/stderr if we're in the old-style conversion mode.
    new_stdout = AnsiToWin32(sys.stdout, convert=None, strip=None, autoreset=False)
    if new_stdout.convert:
        sys.stdout = new_stdout
    new_stderr = AnsiToWin32(sys.stderr, convert=None, strip=None, autoreset=False)
    if new_stderr.convert:
        sys.stderr = new_stderr

    fixed_windows_console = True
+
|
| 95 |
+
@contextlib.contextmanager
def colorama_text(*args, **kwargs):
    """Context manager form of init(): init() on entry, deinit() on exit.

    All arguments are forwarded to init().
    """
    init(*args, **kwargs)
    try:
        yield
    finally:
        deinit()
+
|
| 103 |
+
|
| 104 |
+
def reinit():
    """Re-install the wrappers created by a previous init() after deinit()."""
    for stream_name, wrapper in (('stdout', wrapped_stdout), ('stderr', wrapped_stderr)):
        if wrapper is not None:
            setattr(sys, stream_name, wrapper)
|
| 110 |
+
|
| 111 |
+
def wrap_stream(stream, convert, strip, autoreset, wrap):
    """Return *stream*, wrapped by AnsiToWin32 when wrapping is both
    requested and useful; otherwise return the stream unchanged."""
    if not wrap:
        return stream
    wrapper = AnsiToWin32(stream,
        convert=convert, strip=strip, autoreset=autoreset)
    return wrapper.stream if wrapper.should_wrap() else stream
+
|
| 119 |
+
|
| 120 |
+
# Use this for initial setup as well, to reduce code duplication: it defines
# all the module-level globals (orig_*/wrapped_*/atexit_done/...) at import.
_wipe_internal_state_for_tests()
pythonProject/.venv/Lib/site-packages/colorama/win32.py
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright Jonathan Hartley 2013. BSD 3-Clause license, see LICENSE file.

# from winbase.h
# Standard-device identifiers accepted by GetStdHandle.
STDOUT = -11
STDERR = -12

# Console-mode flag that enables native ANSI/VT sequence handling.
ENABLE_VIRTUAL_TERMINAL_PROCESSING = 0x0004
| 9 |
+
try:
    import ctypes
    from ctypes import LibraryLoader
    windll = LibraryLoader(ctypes.WinDLL)
    from ctypes import wintypes
except (AttributeError, ImportError):
    # Not on Windows (or ctypes lacks WinDLL): expose no-op fallbacks so
    # importers can still reference these names.
    windll = None
    SetConsoleTextAttribute = lambda *_: None
    winapi_test = lambda *_: None
else:
    from ctypes import byref, Structure, c_char, POINTER

    COORD = wintypes._COORD

    class CONSOLE_SCREEN_BUFFER_INFO(Structure):
        """struct in wincon.h."""
        _fields_ = [
            ("dwSize", COORD),
            ("dwCursorPosition", COORD),
            ("wAttributes", wintypes.WORD),
            ("srWindow", wintypes.SMALL_RECT),
            ("dwMaximumWindowSize", COORD),
        ]
        def __str__(self):
            # Compact debug representation of every field.
            return '(%d,%d,%d,%d,%d,%d,%d,%d,%d,%d,%d)' % (
                self.dwSize.Y, self.dwSize.X
                , self.dwCursorPosition.Y, self.dwCursorPosition.X
                , self.wAttributes
                , self.srWindow.Top, self.srWindow.Left, self.srWindow.Bottom, self.srWindow.Right
                , self.dwMaximumWindowSize.Y, self.dwMaximumWindowSize.X
            )

    # Prototype the kernel32 entry points we use (argtypes/restype), so
    # ctypes marshals arguments correctly on both 32- and 64-bit Python.
    _GetStdHandle = windll.kernel32.GetStdHandle
    _GetStdHandle.argtypes = [
        wintypes.DWORD,
    ]
    _GetStdHandle.restype = wintypes.HANDLE

    _GetConsoleScreenBufferInfo = windll.kernel32.GetConsoleScreenBufferInfo
    _GetConsoleScreenBufferInfo.argtypes = [
        wintypes.HANDLE,
        POINTER(CONSOLE_SCREEN_BUFFER_INFO),
    ]
    _GetConsoleScreenBufferInfo.restype = wintypes.BOOL

    _SetConsoleTextAttribute = windll.kernel32.SetConsoleTextAttribute
    _SetConsoleTextAttribute.argtypes = [
        wintypes.HANDLE,
        wintypes.WORD,
    ]
    _SetConsoleTextAttribute.restype = wintypes.BOOL

    _SetConsoleCursorPosition = windll.kernel32.SetConsoleCursorPosition
    _SetConsoleCursorPosition.argtypes = [
        wintypes.HANDLE,
        COORD,
    ]
    _SetConsoleCursorPosition.restype = wintypes.BOOL

    _FillConsoleOutputCharacterA = windll.kernel32.FillConsoleOutputCharacterA
    _FillConsoleOutputCharacterA.argtypes = [
        wintypes.HANDLE,
        c_char,
        wintypes.DWORD,
        COORD,
        POINTER(wintypes.DWORD),
    ]
    _FillConsoleOutputCharacterA.restype = wintypes.BOOL

    _FillConsoleOutputAttribute = windll.kernel32.FillConsoleOutputAttribute
    _FillConsoleOutputAttribute.argtypes = [
        wintypes.HANDLE,
        wintypes.WORD,
        wintypes.DWORD,
        COORD,
        POINTER(wintypes.DWORD),
    ]
    _FillConsoleOutputAttribute.restype = wintypes.BOOL

    _SetConsoleTitleW = windll.kernel32.SetConsoleTitleW
    _SetConsoleTitleW.argtypes = [
        wintypes.LPCWSTR
    ]
    _SetConsoleTitleW.restype = wintypes.BOOL

    _GetConsoleMode = windll.kernel32.GetConsoleMode
    _GetConsoleMode.argtypes = [
        wintypes.HANDLE,
        POINTER(wintypes.DWORD)
    ]
    _GetConsoleMode.restype = wintypes.BOOL

    _SetConsoleMode = windll.kernel32.SetConsoleMode
    _SetConsoleMode.argtypes = [
        wintypes.HANDLE,
        wintypes.DWORD
    ]
    _SetConsoleMode.restype = wintypes.BOOL

    def _winapi_test(handle):
        # A handle is console-like iff GetConsoleScreenBufferInfo succeeds on it.
        csbi = CONSOLE_SCREEN_BUFFER_INFO()
        success = _GetConsoleScreenBufferInfo(
            handle, byref(csbi))
        return bool(success)

    def winapi_test():
        """Return True if stdout or stderr is attached to a real Windows console."""
        return any(_winapi_test(h) for h in
                   (_GetStdHandle(STDOUT), _GetStdHandle(STDERR)))

    def GetConsoleScreenBufferInfo(stream_id=STDOUT):
        """Return the CONSOLE_SCREEN_BUFFER_INFO for the given std stream."""
        handle = _GetStdHandle(stream_id)
        csbi = CONSOLE_SCREEN_BUFFER_INFO()
        success = _GetConsoleScreenBufferInfo(
            handle, byref(csbi))
        return csbi

    def SetConsoleTextAttribute(stream_id, attrs):
        """Set the color/style attribute word on the given std stream."""
        handle = _GetStdHandle(stream_id)
        return _SetConsoleTextAttribute(handle, attrs)

    def SetConsoleCursorPosition(stream_id, position, adjust=True):
        """Move the cursor; *position* is a 1-based ANSI-style (y, x) pair."""
        position = COORD(*position)
        # If the position is out of range, do nothing.
        if position.Y <= 0 or position.X <= 0:
            return
        # Adjust for Windows' SetConsoleCursorPosition:
        #    1. being 0-based, while ANSI is 1-based.
        #    2. expecting (x,y), while ANSI uses (y,x).
        adjusted_position = COORD(position.Y - 1, position.X - 1)
        if adjust:
            # Adjust for viewport's scroll position
            sr = GetConsoleScreenBufferInfo(STDOUT).srWindow
            adjusted_position.Y += sr.Top
            adjusted_position.X += sr.Left
        # Resume normal processing
        handle = _GetStdHandle(stream_id)
        return _SetConsoleCursorPosition(handle, adjusted_position)

    def FillConsoleOutputCharacter(stream_id, char, length, start):
        """Write *char* *length* times starting at COORD *start*; return count written."""
        handle = _GetStdHandle(stream_id)
        char = c_char(char.encode())
        length = wintypes.DWORD(length)
        num_written = wintypes.DWORD(0)
        # Note that this is hard-coded for ANSI (vs wide) bytes.
        success = _FillConsoleOutputCharacterA(
            handle, char, length, start, byref(num_written))
        return num_written.value

    def FillConsoleOutputAttribute(stream_id, attr, length, start):
        ''' FillConsoleOutputAttribute( hConsole, csbi.wAttributes, dwConSize, coordScreen, &cCharsWritten )'''
        handle = _GetStdHandle(stream_id)
        attribute = wintypes.WORD(attr)
        length = wintypes.DWORD(length)
        num_written = wintypes.DWORD(0)
        # Note that this is hard-coded for ANSI (vs wide) bytes.
        return _FillConsoleOutputAttribute(
            handle, attribute, length, start, byref(num_written))

    def SetConsoleTitle(title):
        """Set the console window title (wide-character API)."""
        return _SetConsoleTitleW(title)

    def GetConsoleMode(handle):
        """Return the console mode flags for *handle*; raise WinError on failure."""
        mode = wintypes.DWORD()
        success = _GetConsoleMode(handle, byref(mode))
        if not success:
            raise ctypes.WinError()
        return mode.value

    def SetConsoleMode(handle, mode):
        """Set the console mode flags for *handle*; raise WinError on failure."""
        success = _SetConsoleMode(handle, mode)
        if not success:
            raise ctypes.WinError()
pythonProject/.venv/Lib/site-packages/diffusers/callbacks.py
ADDED
|
@@ -0,0 +1,244 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Any, Dict, List
|
| 2 |
+
|
| 3 |
+
from .configuration_utils import ConfigMixin, register_to_config
|
| 4 |
+
from .utils import CONFIG_NAME
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class PipelineCallback(ConfigMixin):
    """
    Base class for all the official callbacks used in a pipeline. This class provides a structure for implementing
    custom callbacks and ensures that all callbacks have a consistent interface.

    Please implement the following:
        `tensor_inputs`: This should return a list of tensor inputs specific to your callback. You will only be able to
        include
        variables listed in the `._callback_tensor_inputs` attribute of your pipeline class.
        `callback_fn`: This method defines the core functionality of your callback.
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(self, cutoff_step_ratio=1.0, cutoff_step_index=None):
        super().__init__()

        # Exactly one of the two cutoff specifications must be provided.
        if (cutoff_step_ratio is None and cutoff_step_index is None) or (
            cutoff_step_ratio is not None and cutoff_step_index is not None
        ):
            raise ValueError("Either cutoff_step_ratio or cutoff_step_index should be provided, not both or none.")

        # NOTE: the isinstance check deliberately rejects ints such as 1 —
        # callers must pass a float ratio (e.g. 1.0).
        if cutoff_step_ratio is not None and (
            not isinstance(cutoff_step_ratio, float) or not (0.0 <= cutoff_step_ratio <= 1.0)
        ):
            raise ValueError("cutoff_step_ratio must be a float between 0.0 and 1.0.")

    @property
    def tensor_inputs(self) -> List[str]:
        # Subclasses must override with the list of tensor names they consume.
        raise NotImplementedError(f"You need to set the attribute `tensor_inputs` for {self.__class__}")

    def callback_fn(self, pipeline, step_index, timesteps, callback_kwargs) -> Dict[str, Any]:
        # Subclasses implement the per-step callback logic here.
        raise NotImplementedError(f"You need to implement the method `callback_fn` for {self.__class__}")

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # Invoking the callback object delegates straight to callback_fn.
        return self.callback_fn(pipeline, step_index, timestep, callback_kwargs)
+
|
| 45 |
+
|
| 46 |
+
class MultiPipelineCallbacks:
    """
    Wraps a list of `PipelineCallback` objects behind the single-callback interface so a pipeline
    can run several callbacks as if they were one.
    """

    def __init__(self, callbacks: List[PipelineCallback]):
        self.callbacks = callbacks

    @property
    def tensor_inputs(self) -> List[str]:
        # Concatenate each child callback's tensor inputs, preserving callback order.
        names: List[str] = []
        for cb in self.callbacks:
            names.extend(cb.tensor_inputs)
        return names

    def __call__(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        """
        Calls all the callbacks in order with the given arguments and returns the final callback_kwargs.
        """
        for cb in self.callbacks:
            callback_kwargs = cb(pipeline, step_index, timestep, callback_kwargs)
        return callback_kwargs
|
| 67 |
+
|
| 68 |
+
|
| 69 |
+
class SDCFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion pipelines that switches classifier-free guidance off once the
    configured step (via `cutoff_step_ratio` or `cutoff_step_index`) is reached.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Keep only the conditional text embeddings (the last batch entry) and disable CFG.
            key = self.tensor_inputs[0]
            callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
class SDXLCFGCutoffCallback(PipelineCallback):
    """
    Callback for the base Stable Diffusion XL pipelines that switches classifier-free guidance off
    once the configured step (via `cutoff_step_ratio` or `cutoff_step_index`) is reached.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Slice every conditioning tensor down to its conditional (last) batch entry:
            # prompt embeddings, pooled text embeddings, and the added time vector.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
class SDXLControlnetCFGCutoffCallback(PipelineCallback):
    """
    Callback for the ControlNet Stable Diffusion XL pipelines that switches classifier-free
    guidance off once the configured step (via `cutoff_step_ratio` or `cutoff_step_index`) is
    reached.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = [
        "prompt_embeds",
        "add_text_embeds",
        "add_time_ids",
        "image",
    ]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Slice every conditioning tensor (including the ControlNet conditioning image)
            # down to its conditional (last) batch entry, then disable CFG.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class IPAdapterScaleCutoffCallback(PipelineCallback):
    """
    Callback for any pipeline that inherits `IPAdapterMixin`. Once the configured step (via
    `cutoff_step_ratio` or `cutoff_step_index`) is reached, the IP Adapter scale is set to `0.0`.

    Note: This callback mutates the IP Adapter attention processors by setting the scale to 0.0 after the cutoff step.
    """

    # This callback needs no tensors from the pipeline; it only toggles the adapter scale.
    tensor_inputs = []

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            pipeline.set_ip_adapter_scale(0.0)
        return callback_kwargs
|
| 210 |
+
|
| 211 |
+
|
| 212 |
+
class SD3CFGCutoffCallback(PipelineCallback):
    """
    Callback for Stable Diffusion 3 pipelines that switches classifier-free guidance off once the
    configured step (via `cutoff_step_ratio` or `cutoff_step_index`) is reached.

    Note: This callback mutates the pipeline by changing the `_guidance_scale` attribute to 0.0 after the cutoff step.
    """

    tensor_inputs = ["prompt_embeds", "pooled_prompt_embeds"]

    def callback_fn(self, pipeline, step_index, timestep, callback_kwargs) -> Dict[str, Any]:
        # An explicit step index takes precedence; otherwise derive the step from the ratio.
        if self.config.cutoff_step_index is not None:
            cutoff_step = self.config.cutoff_step_index
        else:
            cutoff_step = int(pipeline.num_timesteps * self.config.cutoff_step_ratio)

        if step_index == cutoff_step:
            # Keep only the conditional (last) batch entry of both embedding tensors and disable CFG.
            for key in self.tensor_inputs:
                callback_kwargs[key] = callback_kwargs[key][-1:]
            pipeline._guidance_scale = 0.0

        return callback_kwargs
|
pythonProject/.venv/Lib/site-packages/diffusers/configuration_utils.py
ADDED
|
@@ -0,0 +1,769 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
# Copyright (c) 2022, NVIDIA CORPORATION. All rights reserved.
|
| 4 |
+
#
|
| 5 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 6 |
+
# you may not use this file except in compliance with the License.
|
| 7 |
+
# You may obtain a copy of the License at
|
| 8 |
+
#
|
| 9 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 10 |
+
#
|
| 11 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 12 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 13 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 14 |
+
# See the License for the specific language governing permissions and
|
| 15 |
+
# limitations under the License.
|
| 16 |
+
"""ConfigMixin base class and utilities."""
|
| 17 |
+
|
| 18 |
+
import dataclasses
|
| 19 |
+
import functools
|
| 20 |
+
import importlib
|
| 21 |
+
import inspect
|
| 22 |
+
import json
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
from collections import OrderedDict
|
| 26 |
+
from pathlib import Path
|
| 27 |
+
from typing import Any, Dict, Optional, Tuple, Union
|
| 28 |
+
|
| 29 |
+
import numpy as np
|
| 30 |
+
from huggingface_hub import DDUFEntry, create_repo, hf_hub_download
|
| 31 |
+
from huggingface_hub.utils import (
|
| 32 |
+
EntryNotFoundError,
|
| 33 |
+
RepositoryNotFoundError,
|
| 34 |
+
RevisionNotFoundError,
|
| 35 |
+
validate_hf_hub_args,
|
| 36 |
+
)
|
| 37 |
+
from requests import HTTPError
|
| 38 |
+
from typing_extensions import Self
|
| 39 |
+
|
| 40 |
+
from . import __version__
|
| 41 |
+
from .utils import (
|
| 42 |
+
HUGGINGFACE_CO_RESOLVE_ENDPOINT,
|
| 43 |
+
DummyObject,
|
| 44 |
+
deprecate,
|
| 45 |
+
extract_commit_hash,
|
| 46 |
+
http_user_agent,
|
| 47 |
+
logging,
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
|
| 51 |
+
# Module-level logger for this file.
logger = logging.get_logger(__name__)

# Matches versioned config filenames of the form "config.<something>.json".
_re_configuration_file = re.compile(r"config\.(.*)\.json")
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class FrozenDict(OrderedDict):
    """
    An immutable ``OrderedDict``: entries are fixed at construction time and each key is also
    mirrored as an instance attribute, so ``d.key`` works like ``d["key"]``. Any mutation attempt
    after construction (``pop``, ``update``, item or attribute assignment, ...) raises.
    """

    def __init__(self, *args, **kwargs):
        super().__init__(*args, **kwargs)

        # Mirror every entry as an attribute for convenient `config.key` access.
        for key, value in self.items():
            setattr(self, key, value)

        # Stored (after name mangling) under the key `_FrozenDict__frozen`.
        self.__frozen = True

    def __delitem__(self, *args, **kwargs):
        raise Exception(f"You cannot use ``__delitem__`` on a {self.__class__.__name__} instance.")

    def setdefault(self, *args, **kwargs):
        raise Exception(f"You cannot use ``setdefault`` on a {self.__class__.__name__} instance.")

    def pop(self, *args, **kwargs):
        raise Exception(f"You cannot use ``pop`` on a {self.__class__.__name__} instance.")

    def update(self, *args, **kwargs):
        raise Exception(f"You cannot use ``update`` on a {self.__class__.__name__} instance.")

    def __setattr__(self, name, value):
        # Bug fix: the original guard was `hasattr(self, "__frozen")`, which can never be true —
        # `self.__frozen = True` is stored under the name-mangled attribute `_FrozenDict__frozen`
        # (string literals are NOT mangled), so the guard was dead code and attribute assignment
        # silently succeeded even after freezing. Check the mangled name explicitly.
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setattr__`` on a {self.__class__.__name__} instance.")
        super().__setattr__(name, value)

    def __setitem__(self, name, value):
        # Same name-mangling fix as `__setattr__`; the error message now also names the correct
        # method (the original wrongly reported ``__setattr__`` here).
        if getattr(self, "_FrozenDict__frozen", False):
            raise Exception(f"You cannot use ``__setitem__`` on a {self.__class__.__name__} instance.")
        super().__setitem__(name, value)
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class ConfigMixin:
|
| 89 |
+
r"""
|
| 90 |
+
Base class for all configuration classes. All configuration parameters are stored under `self.config`. Also
|
| 91 |
+
provides the [`~ConfigMixin.from_config`] and [`~ConfigMixin.save_config`] methods for loading, downloading, and
|
| 92 |
+
saving classes that inherit from [`ConfigMixin`].
|
| 93 |
+
|
| 94 |
+
Class attributes:
|
| 95 |
+
- **config_name** (`str`) -- A filename under which the config should stored when calling
|
| 96 |
+
[`~ConfigMixin.save_config`] (should be overridden by parent class).
|
| 97 |
+
- **ignore_for_config** (`List[str]`) -- A list of attributes that should not be saved in the config (should be
|
| 98 |
+
overridden by subclass).
|
| 99 |
+
- **has_compatibles** (`bool`) -- Whether the class has compatible classes (should be overridden by subclass).
|
| 100 |
+
- **_deprecated_kwargs** (`List[str]`) -- Keyword arguments that are deprecated. Note that the `init` function
|
| 101 |
+
should only have a `kwargs` argument if at least one argument is deprecated (should be overridden by
|
| 102 |
+
subclass).
|
| 103 |
+
"""
|
| 104 |
+
|
| 105 |
+
config_name = None
|
| 106 |
+
ignore_for_config = []
|
| 107 |
+
has_compatibles = False
|
| 108 |
+
|
| 109 |
+
_deprecated_kwargs = []
|
| 110 |
+
|
| 111 |
+
def register_to_config(self, **kwargs):
|
| 112 |
+
if self.config_name is None:
|
| 113 |
+
raise NotImplementedError(f"Make sure that {self.__class__} has defined a class name `config_name`")
|
| 114 |
+
# Special case for `kwargs` used in deprecation warning added to schedulers
|
| 115 |
+
# TODO: remove this when we remove the deprecation warning, and the `kwargs` argument,
|
| 116 |
+
# or solve in a more general way.
|
| 117 |
+
kwargs.pop("kwargs", None)
|
| 118 |
+
|
| 119 |
+
if not hasattr(self, "_internal_dict"):
|
| 120 |
+
internal_dict = kwargs
|
| 121 |
+
else:
|
| 122 |
+
previous_dict = dict(self._internal_dict)
|
| 123 |
+
internal_dict = {**self._internal_dict, **kwargs}
|
| 124 |
+
logger.debug(f"Updating config from {previous_dict} to {internal_dict}")
|
| 125 |
+
|
| 126 |
+
self._internal_dict = FrozenDict(internal_dict)
|
| 127 |
+
|
| 128 |
+
    def __getattr__(self, name: str) -> Any:
        """The only reason we overwrite `getattr` here is to gracefully deprecate accessing
        config attributes directly. See https://github.com/huggingface/diffusers/pull/3129

        This function is mostly copied from PyTorch's __getattr__ overwrite:
        https://pytorch.org/docs/stable/_modules/torch/nn/modules/module.html#Module
        """

        # NOTE: `__getattr__` is only invoked when normal attribute lookup has already failed.
        # `self.__dict__` is inspected directly (rather than via `getattr`/`hasattr` on self)
        # to avoid recursing back into this method.
        is_in_config = "_internal_dict" in self.__dict__ and hasattr(self.__dict__["_internal_dict"], name)
        is_attribute = name in self.__dict__

        if is_in_config and not is_attribute:
            # Deprecated fallback: serve the value from the config dict, with a warning pointing
            # users at `obj.config.<name>` instead.
            deprecation_message = f"Accessing config attribute `{name}` directly via '{type(self).__name__}' object attribute is deprecated. Please access '{name}' over '{type(self).__name__}'s config object instead, e.g. 'scheduler.config.{name}'."
            deprecate("direct config name access", "1.0.0", deprecation_message, standard_warn=False)
            return self._internal_dict[name]

        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
|
| 145 |
+
|
| 146 |
+
def save_config(self, save_directory: Union[str, os.PathLike], push_to_hub: bool = False, **kwargs):
|
| 147 |
+
"""
|
| 148 |
+
Save a configuration object to the directory specified in `save_directory` so that it can be reloaded using the
|
| 149 |
+
[`~ConfigMixin.from_config`] class method.
|
| 150 |
+
|
| 151 |
+
Args:
|
| 152 |
+
save_directory (`str` or `os.PathLike`):
|
| 153 |
+
Directory where the configuration JSON file is saved (will be created if it does not exist).
|
| 154 |
+
push_to_hub (`bool`, *optional*, defaults to `False`):
|
| 155 |
+
Whether or not to push your model to the Hugging Face Hub after saving it. You can specify the
|
| 156 |
+
repository you want to push to with `repo_id` (will default to the name of `save_directory` in your
|
| 157 |
+
namespace).
|
| 158 |
+
kwargs (`Dict[str, Any]`, *optional*):
|
| 159 |
+
Additional keyword arguments passed along to the [`~utils.PushToHubMixin.push_to_hub`] method.
|
| 160 |
+
"""
|
| 161 |
+
if os.path.isfile(save_directory):
|
| 162 |
+
raise AssertionError(f"Provided path ({save_directory}) should be a directory, not a file")
|
| 163 |
+
|
| 164 |
+
os.makedirs(save_directory, exist_ok=True)
|
| 165 |
+
|
| 166 |
+
# If we save using the predefined names, we can load using `from_config`
|
| 167 |
+
output_config_file = os.path.join(save_directory, self.config_name)
|
| 168 |
+
|
| 169 |
+
self.to_json_file(output_config_file)
|
| 170 |
+
logger.info(f"Configuration saved in {output_config_file}")
|
| 171 |
+
|
| 172 |
+
if push_to_hub:
|
| 173 |
+
commit_message = kwargs.pop("commit_message", None)
|
| 174 |
+
private = kwargs.pop("private", None)
|
| 175 |
+
create_pr = kwargs.pop("create_pr", False)
|
| 176 |
+
token = kwargs.pop("token", None)
|
| 177 |
+
repo_id = kwargs.pop("repo_id", save_directory.split(os.path.sep)[-1])
|
| 178 |
+
repo_id = create_repo(repo_id, exist_ok=True, private=private, token=token).repo_id
|
| 179 |
+
subfolder = kwargs.pop("subfolder", None)
|
| 180 |
+
|
| 181 |
+
self._upload_folder(
|
| 182 |
+
save_directory,
|
| 183 |
+
repo_id,
|
| 184 |
+
token=token,
|
| 185 |
+
commit_message=commit_message,
|
| 186 |
+
create_pr=create_pr,
|
| 187 |
+
subfolder=subfolder,
|
| 188 |
+
)
|
| 189 |
+
|
| 190 |
+
    @classmethod
    def from_config(
        cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs
    ) -> Union[Self, Tuple[Self, Dict[str, Any]]]:
        r"""
        Instantiate a Python class from a config dictionary.

        Parameters:
            config (`Dict[str, Any]`):
                A config dictionary from which the Python class is instantiated. Make sure to only load configuration
                files of compatible classes.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether kwargs that are not consumed by the Python class should be returned or not.
            kwargs (remaining dictionary of keyword arguments, *optional*):
                Can be used to update the configuration object (after it is loaded) and initiate the Python class.
                `**kwargs` are passed directly to the underlying scheduler/model's `__init__` method and eventually
                overwrite the same named arguments in `config`.

        Returns:
            [`ModelMixin`] or [`SchedulerMixin`]:
                A model or scheduler object instantiated from a config dictionary.

        Examples:

        ```python
        >>> from diffusers import DDPMScheduler, DDIMScheduler, PNDMScheduler

        >>> # Download scheduler from huggingface.co and cache.
        >>> scheduler = DDPMScheduler.from_pretrained("google/ddpm-cifar10-32")

        >>> # Instantiate DDIM scheduler class with same config as DDPM
        >>> scheduler = DDIMScheduler.from_config(scheduler.config)

        >>> # Instantiate PNDM scheduler class with same config as DDPM
        >>> scheduler = PNDMScheduler.from_config(scheduler.config)
        ```
        """
        # <===== TO BE REMOVED WITH DEPRECATION
        # TODO(Patrick) - make sure to remove the following lines when config=="model_path" is deprecated
        if "pretrained_model_name_or_path" in kwargs:
            config = kwargs.pop("pretrained_model_name_or_path")

        if config is None:
            raise ValueError("Please make sure to provide a config as the first positional argument.")
        # ======>

        # Deprecated path: `config` was passed as a model name/path instead of a dict.
        # Warn, then resolve it to an actual config dict via `load_config`.
        if not isinstance(config, dict):
            deprecation_message = "It is deprecated to pass a pretrained model name or path to `from_config`."
            if "Scheduler" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a scheduler, please use {cls}.from_pretrained(...) instead."
                    " Otherwise, please make sure to pass a configuration dictionary instead. This functionality will"
                    " be removed in v1.0.0."
                )
            elif "Model" in cls.__name__:
                deprecation_message += (
                    f"If you were trying to load a model, please use {cls}.load_config(...) followed by"
                    f" {cls}.from_config(...) instead. Otherwise, please make sure to pass a configuration dictionary"
                    " instead. This functionality will be removed in v1.0.0."
                )
            deprecate("config-passed-as-path", "1.0.0", deprecation_message, standard_warn=False)
            config, kwargs = cls.load_config(pretrained_model_name_or_path=config, return_unused_kwargs=True, **kwargs)

        # Split config + kwargs into: arguments consumed by `cls.__init__`, leftover kwargs,
        # and hidden (underscore-prefixed / compatibility) config entries.
        init_dict, unused_kwargs, hidden_dict = cls.extract_init_dict(config, **kwargs)

        # Allow dtype to be specified on initialization
        if "dtype" in unused_kwargs:
            init_dict["dtype"] = unused_kwargs.pop("dtype")

        # add possible deprecated kwargs
        for deprecated_kwarg in cls._deprecated_kwargs:
            if deprecated_kwarg in unused_kwargs:
                init_dict[deprecated_kwarg] = unused_kwargs.pop(deprecated_kwarg)

        # Return model and optionally state and/or unused_kwargs
        model = cls(**init_dict)

        # make sure to also save config parameters that might be used for compatible classes
        # update _class_name
        if "_class_name" in hidden_dict:
            hidden_dict["_class_name"] = cls.__name__

        model.register_to_config(**hidden_dict)

        # add hidden kwargs of compatible classes to unused_kwargs
        unused_kwargs = {**unused_kwargs, **hidden_dict}

        if return_unused_kwargs:
            return (model, unused_kwargs)
        else:
            return model
|
| 281 |
+
|
| 282 |
+
@classmethod
|
| 283 |
+
def get_config_dict(cls, *args, **kwargs):
|
| 284 |
+
deprecation_message = (
|
| 285 |
+
f" The function get_config_dict is deprecated. Please use {cls}.load_config instead. This function will be"
|
| 286 |
+
" removed in version v1.0.0"
|
| 287 |
+
)
|
| 288 |
+
deprecate("get_config_dict", "1.0.0", deprecation_message, standard_warn=False)
|
| 289 |
+
return cls.load_config(*args, **kwargs)
|
| 290 |
+
|
| 291 |
+
    @classmethod
    @validate_hf_hub_args
    def load_config(
        cls,
        pretrained_model_name_or_path: Union[str, os.PathLike],
        return_unused_kwargs=False,
        return_commit_hash=False,
        **kwargs,
    ) -> Tuple[Dict[str, Any], Dict[str, Any]]:
        r"""
        Load a model or scheduler configuration.

        Parameters:
            pretrained_model_name_or_path (`str` or `os.PathLike`, *optional*):
                Can be either:

                - A string, the *model id* (for example `google/ddpm-celebahq-256`) of a pretrained model hosted on
                  the Hub.
                - A path to a *directory* (for example `./my_model_directory`) containing model weights saved with
                  [`~ConfigMixin.save_config`].

            cache_dir (`Union[str, os.PathLike]`, *optional*):
                Path to a directory where a downloaded pretrained model configuration is cached if the standard cache
                is not used.
            force_download (`bool`, *optional*, defaults to `False`):
                Whether or not to force the (re-)download of the model weights and configuration files, overriding the
                cached versions if they exist.
            proxies (`Dict[str, str]`, *optional*):
                A dictionary of proxy servers to use by protocol or endpoint, for example, `{'http': 'foo.bar:3128',
                'http://hostname': 'foo.bar:4012'}`. The proxies are used on each request.
            output_loading_info(`bool`, *optional*, defaults to `False`):
                Whether or not to also return a dictionary containing missing keys, unexpected keys and error messages.
            local_files_only (`bool`, *optional*, defaults to `False`):
                Whether to only load local model weights and configuration files or not. If set to `True`, the model
                won't be downloaded from the Hub.
            token (`str` or *bool*, *optional*):
                The token to use as HTTP bearer authorization for remote files. If `True`, the token generated from
                `diffusers-cli login` (stored in `~/.huggingface`) is used.
            revision (`str`, *optional*, defaults to `"main"`):
                The specific model version to use. It can be a branch name, a tag name, a commit id, or any identifier
                allowed by Git.
            subfolder (`str`, *optional*, defaults to `""`):
                The subfolder location of a model file within a larger model repository on the Hub or locally.
            return_unused_kwargs (`bool`, *optional*, defaults to `False`):
                Whether unused keyword arguments of the config are returned.
            return_commit_hash (`bool`, *optional*, defaults to `False`):
                Whether the `commit_hash` of the loaded configuration are returned.

        Returns:
            `dict`:
                A dictionary of all the parameters stored in a JSON configuration file.
                When `return_unused_kwargs` / `return_commit_hash` are set, a tuple
                `(config_dict[, unused_kwargs][, commit_hash])` is returned instead.
        """
        # Pop every hub-related option out of kwargs so the remainder can be
        # reported back as "unused" kwargs.
        cache_dir = kwargs.pop("cache_dir", None)
        local_dir = kwargs.pop("local_dir", None)
        local_dir_use_symlinks = kwargs.pop("local_dir_use_symlinks", "auto")
        force_download = kwargs.pop("force_download", False)
        proxies = kwargs.pop("proxies", None)
        token = kwargs.pop("token", None)
        local_files_only = kwargs.pop("local_files_only", False)
        revision = kwargs.pop("revision", None)
        _ = kwargs.pop("mirror", None)  # accepted for API compatibility, ignored
        subfolder = kwargs.pop("subfolder", None)
        user_agent = kwargs.pop("user_agent", {})
        dduf_entries: Optional[Dict[str, DDUFEntry]] = kwargs.pop("dduf_entries", None)

        # Tag download telemetry as a config fetch.
        user_agent = {**user_agent, "file_type": "config"}
        user_agent = http_user_agent(user_agent)

        pretrained_model_name_or_path = str(pretrained_model_name_or_path)

        if cls.config_name is None:
            raise ValueError(
                "`self.config_name` is not defined. Note that one should not load a config from "
                "`ConfigMixin`. Please make sure to define `config_name` in a class inheriting from `ConfigMixin`"
            )
        # Resolution order: DDUF archive > local file > local directory > Hub download.
        # Custom path for now
        if dduf_entries:
            if subfolder is not None:
                raise ValueError(
                    "DDUF file only allow for 1 level of directory (e.g transformer/model1/model.safetentors is not allowed). "
                    "Please check the DDUF structure"
                )
            config_file = cls._get_config_file_from_dduf(pretrained_model_name_or_path, dduf_entries)
        elif os.path.isfile(pretrained_model_name_or_path):
            config_file = pretrained_model_name_or_path
        elif os.path.isdir(pretrained_model_name_or_path):
            if subfolder is not None and os.path.isfile(
                os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            ):
                config_file = os.path.join(pretrained_model_name_or_path, subfolder, cls.config_name)
            elif os.path.isfile(os.path.join(pretrained_model_name_or_path, cls.config_name)):
                # Load from a PyTorch checkpoint
                config_file = os.path.join(pretrained_model_name_or_path, cls.config_name)
            else:
                raise EnvironmentError(
                    f"Error no file named {cls.config_name} found in directory {pretrained_model_name_or_path}."
                )
        else:
            try:
                # Load from URL or cache if already cached
                config_file = hf_hub_download(
                    pretrained_model_name_or_path,
                    filename=cls.config_name,
                    cache_dir=cache_dir,
                    force_download=force_download,
                    proxies=proxies,
                    local_files_only=local_files_only,
                    token=token,
                    user_agent=user_agent,
                    subfolder=subfolder,
                    revision=revision,
                    local_dir=local_dir,
                    local_dir_use_symlinks=local_dir_use_symlinks,
                )
            # Re-wrap every hub error into EnvironmentError with an actionable message.
            except RepositoryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} is not a local folder and is not a valid model identifier"
                    " listed on 'https://huggingface.co/models'\nIf this is a private repository, make sure to pass a"
                    " token having permission to this repo with `token` or log in with `hf auth login`."
                )
            except RevisionNotFoundError:
                raise EnvironmentError(
                    f"{revision} is not a valid git identifier (branch name, tag name or commit id) that exists for"
                    " this model name. Check the model page at"
                    f" 'https://huggingface.co/{pretrained_model_name_or_path}' for available revisions."
                )
            except EntryNotFoundError:
                raise EnvironmentError(
                    f"{pretrained_model_name_or_path} does not appear to have a file named {cls.config_name}."
                )
            except HTTPError as err:
                raise EnvironmentError(
                    "There was a specific connection error when trying to load"
                    f" {pretrained_model_name_or_path}:\n{err}"
                )
            except ValueError:
                raise EnvironmentError(
                    f"We couldn't connect to '{HUGGINGFACE_CO_RESOLVE_ENDPOINT}' to load this model, couldn't find it"
                    f" in the cached files and it looks like {pretrained_model_name_or_path} is not the path to a"
                    f" directory containing a {cls.config_name} file.\nCheckout your internet connection or see how to"
                    " run the library in offline mode at"
                    " 'https://huggingface.co/docs/diffusers/installation#offline-mode'."
                )
            except EnvironmentError:
                raise EnvironmentError(
                    f"Can't load config for '{pretrained_model_name_or_path}'. If you were trying to load it from "
                    "'https://huggingface.co/models', make sure you don't have a local directory with the same name. "
                    f"Otherwise, make sure '{pretrained_model_name_or_path}' is the correct path to a directory "
                    f"containing a {cls.config_name} file"
                )
        try:
            config_dict = cls._dict_from_json_file(config_file, dduf_entries=dduf_entries)

            commit_hash = extract_commit_hash(config_file)
        except (json.JSONDecodeError, UnicodeDecodeError):
            raise EnvironmentError(f"It looks like the config file at '{config_file}' is not a valid JSON file.")

        # Fast path: plain dict return when no extras were requested.
        if not (return_unused_kwargs or return_commit_hash):
            return config_dict

        outputs = (config_dict,)

        if return_unused_kwargs:
            outputs += (kwargs,)

        if return_commit_hash:
            outputs += (commit_hash,)

        return outputs
|
| 461 |
+
|
| 462 |
+
@staticmethod
|
| 463 |
+
def _get_init_keys(input_class):
|
| 464 |
+
return set(dict(inspect.signature(input_class.__init__).parameters).keys())
|
| 465 |
+
|
| 466 |
+
    @classmethod
    def extract_init_dict(cls, config_dict, **kwargs):
        """Split `config_dict` (plus override `kwargs`) into the kwargs accepted by
        `cls.__init__`, the leftover unused kwargs, and the "hidden" entries kept
        only for compatible-class round-tripping.

        Returns:
            `(init_dict, unused_kwargs, hidden_config_dict)` — three plain dicts.
        """
        # Skip keys that were not present in the original config, so default __init__ values were used
        used_defaults = config_dict.get("_use_default_values", [])
        config_dict = {k: v for k, v in config_dict.items() if k not in used_defaults and k != "_use_default_values"}

        # 0. Copy origin config dict (kept intact so step 7 can compute hidden keys)
        original_dict = dict(config_dict.items())

        # 1. Retrieve expected config attributes from __init__ signature
        expected_keys = cls._get_init_keys(cls)
        expected_keys.remove("self")
        # remove general kwargs if present in dict
        if "kwargs" in expected_keys:
            expected_keys.remove("kwargs")
        # remove flax internal keys
        if hasattr(cls, "_flax_internal_args"):
            for arg in cls._flax_internal_args:
                expected_keys.remove(arg)

        # 2. Remove attributes that cannot be expected from expected config attributes
        # remove keys to be ignored
        if len(cls.ignore_for_config) > 0:
            expected_keys = expected_keys - set(cls.ignore_for_config)

        # load diffusers library to import compatible and original scheduler
        diffusers_library = importlib.import_module(__name__.split(".")[0])

        if cls.has_compatibles:
            compatible_classes = [c for c in cls._get_compatibles() if not isinstance(c, DummyObject)]
        else:
            compatible_classes = []

        # Drop keys that only exist on compatible sibling classes, not on `cls` itself.
        expected_keys_comp_cls = set()
        for c in compatible_classes:
            expected_keys_c = cls._get_init_keys(c)
            expected_keys_comp_cls = expected_keys_comp_cls.union(expected_keys_c)
        expected_keys_comp_cls = expected_keys_comp_cls - cls._get_init_keys(cls)
        config_dict = {k: v for k, v in config_dict.items() if k not in expected_keys_comp_cls}

        # remove attributes from orig class that cannot be expected
        orig_cls_name = config_dict.pop("_class_name", cls.__name__)
        if (
            isinstance(orig_cls_name, str)
            and orig_cls_name != cls.__name__
            and hasattr(diffusers_library, orig_cls_name)
        ):
            orig_cls = getattr(diffusers_library, orig_cls_name)
            unexpected_keys_from_orig = cls._get_init_keys(orig_cls) - expected_keys
            config_dict = {k: v for k, v in config_dict.items() if k not in unexpected_keys_from_orig}
        elif not isinstance(orig_cls_name, str) and not isinstance(orig_cls_name, (list, tuple)):
            raise ValueError(
                "Make sure that the `_class_name` is of type string or list of string (for custom pipelines)."
            )

        # remove private attributes
        config_dict = {k: v for k, v in config_dict.items() if not k.startswith("_")}

        # remove quantization_config
        config_dict = {k: v for k, v in config_dict.items() if k != "quantization_config"}

        # 3. Create keyword arguments that will be passed to __init__ from expected keyword arguments
        init_dict = {}
        for key in expected_keys:
            # if config param is passed to kwarg and is present in config dict
            # it should overwrite existing config dict key
            if key in kwargs and key in config_dict:
                config_dict[key] = kwargs.pop(key)

            if key in kwargs:
                # overwrite key
                init_dict[key] = kwargs.pop(key)
            elif key in config_dict:
                # use value from config dict
                init_dict[key] = config_dict.pop(key)

        # 4. Give nice warning if unexpected values have been passed
        if len(config_dict) > 0:
            logger.warning(
                f"The config attributes {config_dict} were passed to {cls.__name__}, "
                "but are not expected and will be ignored. Please verify your "
                f"{cls.config_name} configuration file."
            )

        # 5. Give nice info if config attributes are initialized to default because they have not been passed
        passed_keys = set(init_dict.keys())
        if len(expected_keys - passed_keys) > 0:
            logger.info(
                f"{expected_keys - passed_keys} was not found in config. Values will be initialized to default values."
            )

        # 6. Define unused keyword arguments
        unused_kwargs = {**config_dict, **kwargs}

        # 7. Define "hidden" config parameters that were saved for compatible classes
        hidden_config_dict = {k: v for k, v in original_dict.items() if k not in init_dict}

        return init_dict, unused_kwargs, hidden_config_dict
|
| 564 |
+
|
| 565 |
+
@classmethod
|
| 566 |
+
def _dict_from_json_file(
|
| 567 |
+
cls, json_file: Union[str, os.PathLike], dduf_entries: Optional[Dict[str, DDUFEntry]] = None
|
| 568 |
+
):
|
| 569 |
+
if dduf_entries:
|
| 570 |
+
text = dduf_entries[json_file].read_text()
|
| 571 |
+
else:
|
| 572 |
+
with open(json_file, "r", encoding="utf-8") as reader:
|
| 573 |
+
text = reader.read()
|
| 574 |
+
return json.loads(text)
|
| 575 |
+
|
| 576 |
+
def __repr__(self):
|
| 577 |
+
return f"{self.__class__.__name__} {self.to_json_string()}"
|
| 578 |
+
|
| 579 |
+
    @property
    def config(self) -> Dict[str, Any]:
        """
        Returns the config of the class as a frozen dictionary

        Returns:
            `Dict[str, Any]`: Config of the class.
        """
        # Exposes the internal registered-config storage directly; callers must
        # not mutate it (it is populated via `register_to_config`).
        return self._internal_dict
|
| 588 |
+
|
| 589 |
+
    def to_json_string(self) -> str:
        """
        Serializes the configuration instance to a JSON string.

        Returns:
            `str`:
                String containing all the attributes that make up the configuration instance in JSON format
                (sorted keys, 2-space indent, trailing newline).
        """
        config_dict = self._internal_dict if hasattr(self, "_internal_dict") else {}
        config_dict["_class_name"] = self.__class__.__name__
        config_dict["_diffusers_version"] = __version__

        def to_json_saveable(value):
            # Coerce non-JSON-native values; recurses only into lists.
            if isinstance(value, np.ndarray):
                value = value.tolist()
            elif isinstance(value, Path):
                value = value.as_posix()
            elif hasattr(value, "to_dict") and callable(value.to_dict):
                value = value.to_dict()
            elif isinstance(value, list):
                value = [to_json_saveable(v) for v in value]
            return value

        if "quantization_config" in config_dict:
            # NOTE(review): attribute access on config_dict here presumably relies on
            # `_internal_dict` being a FrozenDict that supports it — confirm.
            config_dict["quantization_config"] = (
                config_dict.quantization_config.to_dict()
                if not isinstance(config_dict.quantization_config, dict)
                else config_dict.quantization_config
            )

        config_dict = {k: to_json_saveable(v) for k, v in config_dict.items()}
        # Don't save "_ignore_files" or "_use_default_values"
        config_dict.pop("_ignore_files", None)
        config_dict.pop("_use_default_values", None)
        # pop the `_pre_quantization_dtype` as torch.dtypes are not serializable.
        _ = config_dict.pop("_pre_quantization_dtype", None)

        return json.dumps(config_dict, indent=2, sort_keys=True) + "\n"
|
| 627 |
+
|
| 628 |
+
def to_json_file(self, json_file_path: Union[str, os.PathLike]):
|
| 629 |
+
"""
|
| 630 |
+
Save the configuration instance's parameters to a JSON file.
|
| 631 |
+
|
| 632 |
+
Args:
|
| 633 |
+
json_file_path (`str` or `os.PathLike`):
|
| 634 |
+
Path to the JSON file to save a configuration instance's parameters.
|
| 635 |
+
"""
|
| 636 |
+
with open(json_file_path, "w", encoding="utf-8") as writer:
|
| 637 |
+
writer.write(self.to_json_string())
|
| 638 |
+
|
| 639 |
+
@classmethod
|
| 640 |
+
def _get_config_file_from_dduf(cls, pretrained_model_name_or_path: str, dduf_entries: Dict[str, DDUFEntry]):
|
| 641 |
+
# paths inside a DDUF file must always be "/"
|
| 642 |
+
config_file = (
|
| 643 |
+
cls.config_name
|
| 644 |
+
if pretrained_model_name_or_path == ""
|
| 645 |
+
else "/".join([pretrained_model_name_or_path, cls.config_name])
|
| 646 |
+
)
|
| 647 |
+
if config_file not in dduf_entries:
|
| 648 |
+
raise ValueError(
|
| 649 |
+
f"We did not manage to find the file {config_file} in the dduf file. We only have the following files {dduf_entries.keys()}"
|
| 650 |
+
)
|
| 651 |
+
return config_file
|
| 652 |
+
|
| 653 |
+
|
| 654 |
+
def register_to_config(init):
    r"""
    Decorator to apply on the init of classes inheriting from [`ConfigMixin`] so that all the arguments are
    automatically sent to `self.register_for_config`. To ignore a specific argument accepted by the init but that
    shouldn't be registered in the config, use the `ignore_for_config` class variable

    Warning: Once decorated, all private arguments (beginning with an underscore) are trashed and not sent to the init!
    """

    @functools.wraps(init)
    def inner_init(self, *args, **kwargs):
        # Ignore private kwargs in the init.
        init_kwargs = {k: v for k, v in kwargs.items() if not k.startswith("_")}
        config_init_kwargs = {k: v for k, v in kwargs.items() if k.startswith("_")}
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        ignore = getattr(self, "ignore_for_config", [])
        # Get positional arguments aligned with kwargs
        new_kwargs = {}
        signature = inspect.signature(init)
        # i > 0 skips `self`; ignored names never enter the registered config.
        parameters = {
            name: p.default for i, (name, p) in enumerate(signature.parameters.items()) if i > 0 and name not in ignore
        }
        # Map positional args onto parameter names in declaration order.
        for arg, name in zip(args, parameters.keys()):
            new_kwargs[name] = arg

        # Then add all kwargs (falling back to the signature's declared defaults)
        new_kwargs.update(
            {
                k: init_kwargs.get(k, default)
                for k, default in parameters.items()
                if k not in ignore and k not in new_kwargs
            }
        )

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        # Private ("_"-prefixed) kwargs are registered but not forwarded to init.
        new_kwargs = {**config_init_kwargs, **new_kwargs}
        getattr(self, "register_to_config")(**new_kwargs)
        init(self, *args, **init_kwargs)

    return inner_init
|
| 702 |
+
|
| 703 |
+
|
| 704 |
+
def flax_register_to_config(cls):
    """Class decorator (Flax variant of `register_to_config`): wraps `cls.__init__`
    so that dataclass-field values are registered to the config before init runs."""
    original_init = cls.__init__

    @functools.wraps(original_init)
    def init(self, *args, **kwargs):
        if not isinstance(self, ConfigMixin):
            raise RuntimeError(
                f"`@register_for_config` was applied to {self.__class__.__name__} init method, but this class does "
                "not inherit from `ConfigMixin`."
            )

        # Ignore private kwargs in the init. Retrieve all passed attributes
        init_kwargs = dict(kwargs.items())

        # Retrieve default values
        fields = dataclasses.fields(self)
        default_kwargs = {}
        for field in fields:
            # ignore flax specific attributes
            if field.name in self._flax_internal_args:
                continue
            # NOTE(review): `type(x) == dataclasses._MISSING_TYPE` relies on a private
            # type; `field.default is dataclasses.MISSING` would be the public check.
            if type(field.default) == dataclasses._MISSING_TYPE:
                default_kwargs[field.name] = None
            else:
                default_kwargs[field.name] = getattr(self, field.name)

        # Make sure init_kwargs override default kwargs
        new_kwargs = {**default_kwargs, **init_kwargs}
        # dtype should be part of `init_kwargs`, but not `new_kwargs`
        if "dtype" in new_kwargs:
            new_kwargs.pop("dtype")

        # Get positional arguments aligned with kwargs
        # (positional args map onto dataclass fields in declaration order)
        for i, arg in enumerate(args):
            name = fields[i].name
            new_kwargs[name] = arg

        # Take note of the parameters that were not present in the loaded config
        if len(set(new_kwargs.keys()) - set(init_kwargs)) > 0:
            new_kwargs["_use_default_values"] = list(set(new_kwargs.keys()) - set(init_kwargs))

        getattr(self, "register_to_config")(**new_kwargs)
        original_init(self, *args, **kwargs)

    cls.__init__ = init
    return cls
|
| 750 |
+
|
| 751 |
+
|
| 752 |
+
class LegacyConfigMixin(ConfigMixin):
    r"""
    A subclass of `ConfigMixin` to resolve class mapping from legacy classes (like `Transformer2DModel`) to more
    pipeline-specific classes (like `DiTTransformer2DModel`).
    """

    @classmethod
    def from_config(cls, config: Union[FrozenDict, Dict[str, Any]] = None, return_unused_kwargs=False, **kwargs):
        """Instantiate via `ConfigMixin.from_config`, first remapping `cls` to a more
        specific model class when the config calls for one."""
        # To prevent dependency import problem.
        from .models.model_loading_utils import _fetch_remapped_cls_from_config

        # resolve remapping
        remapped_class = _fetch_remapped_cls_from_config(config, cls)

        if remapped_class is cls:
            # No remap: delegate to the plain ConfigMixin implementation.
            return super(LegacyConfigMixin, remapped_class).from_config(config, return_unused_kwargs, **kwargs)
        else:
            # Remapped class handles its own (possibly also legacy-aware) from_config.
            return remapped_class.from_config(config, return_unused_kwargs, **kwargs)
|
pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_check.py
ADDED
|
@@ -0,0 +1,34 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
from .dependency_versions_table import deps
|
| 16 |
+
from .utils.versions import require_version, require_version_core
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
# define which module versions we always want to check at run time
# (usually the ones defined in `install_requires` in setup.py)
#
# order specific notes:
# - tqdm must be checked before tokenizers

pkgs_to_check_at_runtime = "python requests filelock numpy".split()
for pkg in pkgs_to_check_at_runtime:
    # Guard clause: a package missing from the table is a packaging bug.
    if pkg not in deps:
        raise ValueError(f"can't find {pkg} in {deps.keys()}, check dependency_versions_table.py")
    require_version_core(deps[pkg])
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def dep_version_check(pkg, hint=None):
    """Check the installed version of `pkg` against its pin in `deps` via
    `require_version`; `hint` is presumably surfaced in the failure message."""
    require_version(deps[pkg], hint)
|
pythonProject/.venv/Lib/site-packages/diffusers/dependency_versions_table.py
ADDED
|
@@ -0,0 +1,54 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# THIS FILE HAS BEEN AUTOGENERATED. To update:
# 1. modify the `_deps` dict in setup.py
# 2. run `make deps_table_update`
# Maps a canonical package name to its full pip requirement specifier.
deps = {
    "Pillow": "Pillow",
    "accelerate": "accelerate>=0.31.0",
    "compel": "compel==0.1.8",
    "datasets": "datasets",
    "filelock": "filelock",
    "flax": "flax>=0.4.1",
    "hf-doc-builder": "hf-doc-builder>=0.3.0",
    "huggingface-hub": "huggingface-hub>=0.34.0",
    "requests-mock": "requests-mock==1.10.0",
    "importlib_metadata": "importlib_metadata",
    "invisible-watermark": "invisible-watermark>=0.2.0",
    "isort": "isort>=5.5.4",
    "jax": "jax>=0.4.1",
    "jaxlib": "jaxlib>=0.4.1",
    "Jinja2": "Jinja2",
    "k-diffusion": "k-diffusion==0.0.12",
    "torchsde": "torchsde",
    "note_seq": "note_seq",
    "librosa": "librosa",
    "numpy": "numpy",
    "parameterized": "parameterized",
    "peft": "peft>=0.17.0",
    "protobuf": "protobuf>=3.20.3,<4",
    "pytest": "pytest",
    "pytest-timeout": "pytest-timeout",
    "pytest-xdist": "pytest-xdist",
    "python": "python>=3.8.0",
    "ruff": "ruff==0.9.10",
    "safetensors": "safetensors>=0.3.1",
    "sentencepiece": "sentencepiece>=0.1.91,!=0.1.92",
    "GitPython": "GitPython<3.1.19",
    "scipy": "scipy",
    "onnx": "onnx",
    "optimum_quanto": "optimum_quanto>=0.2.6",
    "gguf": "gguf>=0.10.0",
    "torchao": "torchao>=0.7.0",
    "bitsandbytes": "bitsandbytes>=0.43.3",
    "nvidia_modelopt[hf]": "nvidia_modelopt[hf]>=0.33.1",
    "regex": "regex!=2019.12.17",
    "requests": "requests",
    "tensorboard": "tensorboard",
    "tiktoken": "tiktoken>=0.7.0",
    "torch": "torch>=1.4",
    "torchvision": "torchvision",
    "transformers": "transformers>=4.41.2",
    "urllib3": "urllib3<=2.0.0",
    "black": "black",
    "phonemizer": "phonemizer",
    "opencv-python": "opencv-python",
}
|
pythonProject/.venv/Lib/site-packages/diffusers/image_processor.py
ADDED
|
@@ -0,0 +1,1451 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import math
|
| 16 |
+
import warnings
|
| 17 |
+
from typing import List, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import PIL.Image
|
| 21 |
+
import torch
|
| 22 |
+
import torch.nn.functional as F
|
| 23 |
+
from PIL import Image, ImageFilter, ImageOps
|
| 24 |
+
|
| 25 |
+
from .configuration_utils import ConfigMixin, register_to_config
|
| 26 |
+
from .utils import CONFIG_NAME, PIL_INTERPOLATION, deprecate
|
| 27 |
+
|
| 28 |
+
|
| 29 |
+
PipelineImageInput = Union[
|
| 30 |
+
PIL.Image.Image,
|
| 31 |
+
np.ndarray,
|
| 32 |
+
torch.Tensor,
|
| 33 |
+
List[PIL.Image.Image],
|
| 34 |
+
List[np.ndarray],
|
| 35 |
+
List[torch.Tensor],
|
| 36 |
+
]
|
| 37 |
+
|
| 38 |
+
PipelineDepthInput = PipelineImageInput
|
| 39 |
+
|
| 40 |
+
|
| 41 |
+
def is_valid_image(image) -> bool:
    r"""
    Checks if the input is a valid image.

    A valid image can be:
    - A `PIL.Image.Image`.
    - A 2D or 3D `np.ndarray` or `torch.Tensor` (grayscale or color image).

    Args:
        image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
            The image to validate. It can be a PIL image, a NumPy array, or a torch tensor.

    Returns:
        `bool`:
            `True` if the input is a valid image, `False` otherwise.
    """
    # PIL images are always valid, whatever their mode.
    if isinstance(image, PIL.Image.Image):
        return True
    # Arrays/tensors qualify only when 2D (grayscale) or 3D (with a channel axis).
    return isinstance(image, (np.ndarray, torch.Tensor)) and image.ndim in (2, 3)
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def is_valid_image_imagelist(images):
    r"""
    Checks if the input is a valid image or list of images.

    The input can be one of the following formats:
    - A 4D tensor or numpy array (batch of images).
    - A valid single image: `PIL.Image.Image`, 2D `np.ndarray` or `torch.Tensor` (grayscale image), 3D `np.ndarray` or
      `torch.Tensor`.
    - A list of valid images.

    Args:
        images (`Union[np.ndarray, torch.Tensor, PIL.Image.Image, List]`):
            The image(s) to check. Can be a batch of images (4D tensor/array), a single image, or a list of valid
            images.

    Returns:
        `bool`:
            `True` if the input is valid, `False` otherwise.
    """
    # A 4D array/tensor is treated as an already-batched stack of images.
    is_batched_array = isinstance(images, (np.ndarray, torch.Tensor)) and images.ndim == 4
    if is_batched_array or is_valid_image(images):
        return True
    if isinstance(images, list):
        # Every element of a list input must itself be a valid single image.
        return all(is_valid_image(img) for img in images)
    return False
|
| 86 |
+
|
| 87 |
+
|
| 88 |
+
class VaeImageProcessor(ConfigMixin):
|
| 89 |
+
"""
|
| 90 |
+
Image processor for VAE.
|
| 91 |
+
|
| 92 |
+
Args:
|
| 93 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 94 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
|
| 95 |
+
`height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
|
| 96 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
| 97 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
| 98 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
| 99 |
+
Resampling filter to use when resizing the image.
|
| 100 |
+
do_normalize (`bool`, *optional*, defaults to `True`):
|
| 101 |
+
Whether to normalize the image to [-1,1].
|
| 102 |
+
do_binarize (`bool`, *optional*, defaults to `False`):
|
| 103 |
+
Whether to binarize the image to 0/1.
|
| 104 |
+
do_convert_rgb (`bool`, *optional*, defaults to be `False`):
|
| 105 |
+
Whether to convert the images to RGB format.
|
| 106 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
|
| 107 |
+
Whether to convert the images to grayscale format.
|
| 108 |
+
"""
|
| 109 |
+
|
| 110 |
+
config_name = CONFIG_NAME
|
| 111 |
+
|
| 112 |
+
@register_to_config
|
| 113 |
+
def __init__(
|
| 114 |
+
self,
|
| 115 |
+
do_resize: bool = True,
|
| 116 |
+
vae_scale_factor: int = 8,
|
| 117 |
+
vae_latent_channels: int = 4,
|
| 118 |
+
resample: str = "lanczos",
|
| 119 |
+
reducing_gap: int = None,
|
| 120 |
+
do_normalize: bool = True,
|
| 121 |
+
do_binarize: bool = False,
|
| 122 |
+
do_convert_rgb: bool = False,
|
| 123 |
+
do_convert_grayscale: bool = False,
|
| 124 |
+
):
|
| 125 |
+
super().__init__()
|
| 126 |
+
if do_convert_rgb and do_convert_grayscale:
|
| 127 |
+
raise ValueError(
|
| 128 |
+
"`do_convert_rgb` and `do_convert_grayscale` can not both be set to `True`,"
|
| 129 |
+
" if you intended to convert the image into RGB format, please set `do_convert_grayscale = False`.",
|
| 130 |
+
" if you intended to convert the image into grayscale format, please set `do_convert_rgb = False`",
|
| 131 |
+
)
|
| 132 |
+
|
| 133 |
+
@staticmethod
|
| 134 |
+
def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
|
| 135 |
+
r"""
|
| 136 |
+
Convert a numpy image or a batch of images to a PIL image.
|
| 137 |
+
|
| 138 |
+
Args:
|
| 139 |
+
images (`np.ndarray`):
|
| 140 |
+
The image array to convert to PIL format.
|
| 141 |
+
|
| 142 |
+
Returns:
|
| 143 |
+
`List[PIL.Image.Image]`:
|
| 144 |
+
A list of PIL images.
|
| 145 |
+
"""
|
| 146 |
+
if images.ndim == 3:
|
| 147 |
+
images = images[None, ...]
|
| 148 |
+
images = (images * 255).round().astype("uint8")
|
| 149 |
+
if images.shape[-1] == 1:
|
| 150 |
+
# special case for grayscale (single channel) images
|
| 151 |
+
pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
|
| 152 |
+
else:
|
| 153 |
+
pil_images = [Image.fromarray(image) for image in images]
|
| 154 |
+
|
| 155 |
+
return pil_images
|
| 156 |
+
|
| 157 |
+
@staticmethod
|
| 158 |
+
def pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
|
| 159 |
+
r"""
|
| 160 |
+
Convert a PIL image or a list of PIL images to NumPy arrays.
|
| 161 |
+
|
| 162 |
+
Args:
|
| 163 |
+
images (`PIL.Image.Image` or `List[PIL.Image.Image]`):
|
| 164 |
+
The PIL image or list of images to convert to NumPy format.
|
| 165 |
+
|
| 166 |
+
Returns:
|
| 167 |
+
`np.ndarray`:
|
| 168 |
+
A NumPy array representation of the images.
|
| 169 |
+
"""
|
| 170 |
+
if not isinstance(images, list):
|
| 171 |
+
images = [images]
|
| 172 |
+
images = [np.array(image).astype(np.float32) / 255.0 for image in images]
|
| 173 |
+
images = np.stack(images, axis=0)
|
| 174 |
+
|
| 175 |
+
return images
|
| 176 |
+
|
| 177 |
+
@staticmethod
|
| 178 |
+
def numpy_to_pt(images: np.ndarray) -> torch.Tensor:
|
| 179 |
+
r"""
|
| 180 |
+
Convert a NumPy image to a PyTorch tensor.
|
| 181 |
+
|
| 182 |
+
Args:
|
| 183 |
+
images (`np.ndarray`):
|
| 184 |
+
The NumPy image array to convert to PyTorch format.
|
| 185 |
+
|
| 186 |
+
Returns:
|
| 187 |
+
`torch.Tensor`:
|
| 188 |
+
A PyTorch tensor representation of the images.
|
| 189 |
+
"""
|
| 190 |
+
if images.ndim == 3:
|
| 191 |
+
images = images[..., None]
|
| 192 |
+
|
| 193 |
+
images = torch.from_numpy(images.transpose(0, 3, 1, 2))
|
| 194 |
+
return images
|
| 195 |
+
|
| 196 |
+
@staticmethod
|
| 197 |
+
def pt_to_numpy(images: torch.Tensor) -> np.ndarray:
|
| 198 |
+
r"""
|
| 199 |
+
Convert a PyTorch tensor to a NumPy image.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
images (`torch.Tensor`):
|
| 203 |
+
The PyTorch tensor to convert to NumPy format.
|
| 204 |
+
|
| 205 |
+
Returns:
|
| 206 |
+
`np.ndarray`:
|
| 207 |
+
A NumPy array representation of the images.
|
| 208 |
+
"""
|
| 209 |
+
images = images.cpu().permute(0, 2, 3, 1).float().numpy()
|
| 210 |
+
return images
|
| 211 |
+
|
| 212 |
+
@staticmethod
|
| 213 |
+
def normalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
| 214 |
+
r"""
|
| 215 |
+
Normalize an image array to [-1,1].
|
| 216 |
+
|
| 217 |
+
Args:
|
| 218 |
+
images (`np.ndarray` or `torch.Tensor`):
|
| 219 |
+
The image array to normalize.
|
| 220 |
+
|
| 221 |
+
Returns:
|
| 222 |
+
`np.ndarray` or `torch.Tensor`:
|
| 223 |
+
The normalized image array.
|
| 224 |
+
"""
|
| 225 |
+
return 2.0 * images - 1.0
|
| 226 |
+
|
| 227 |
+
@staticmethod
|
| 228 |
+
def denormalize(images: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
|
| 229 |
+
r"""
|
| 230 |
+
Denormalize an image array to [0,1].
|
| 231 |
+
|
| 232 |
+
Args:
|
| 233 |
+
images (`np.ndarray` or `torch.Tensor`):
|
| 234 |
+
The image array to denormalize.
|
| 235 |
+
|
| 236 |
+
Returns:
|
| 237 |
+
`np.ndarray` or `torch.Tensor`:
|
| 238 |
+
The denormalized image array.
|
| 239 |
+
"""
|
| 240 |
+
return (images * 0.5 + 0.5).clamp(0, 1)
|
| 241 |
+
|
| 242 |
+
@staticmethod
|
| 243 |
+
def convert_to_rgb(image: PIL.Image.Image) -> PIL.Image.Image:
|
| 244 |
+
r"""
|
| 245 |
+
Converts a PIL image to RGB format.
|
| 246 |
+
|
| 247 |
+
Args:
|
| 248 |
+
image (`PIL.Image.Image`):
|
| 249 |
+
The PIL image to convert to RGB.
|
| 250 |
+
|
| 251 |
+
Returns:
|
| 252 |
+
`PIL.Image.Image`:
|
| 253 |
+
The RGB-converted PIL image.
|
| 254 |
+
"""
|
| 255 |
+
image = image.convert("RGB")
|
| 256 |
+
|
| 257 |
+
return image
|
| 258 |
+
|
| 259 |
+
@staticmethod
|
| 260 |
+
def convert_to_grayscale(image: PIL.Image.Image) -> PIL.Image.Image:
|
| 261 |
+
r"""
|
| 262 |
+
Converts a given PIL image to grayscale.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
image (`PIL.Image.Image`):
|
| 266 |
+
The input image to convert.
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
`PIL.Image.Image`:
|
| 270 |
+
The image converted to grayscale.
|
| 271 |
+
"""
|
| 272 |
+
image = image.convert("L")
|
| 273 |
+
|
| 274 |
+
return image
|
| 275 |
+
|
| 276 |
+
    @staticmethod
    def blur(image: PIL.Image.Image, blur_factor: int = 4) -> PIL.Image.Image:
        r"""
        Applies Gaussian blur to an image.

        Args:
            image (`PIL.Image.Image`):
                The PIL image to blur.
            blur_factor (`int`, *optional*, defaults to 4):
                The radius of the Gaussian blur kernel.

        Returns:
            `PIL.Image.Image`:
                The blurred PIL image.
        """
        image = image.filter(ImageFilter.GaussianBlur(blur_factor))

        return image
|
| 292 |
+
|
| 293 |
+
    @staticmethod
    def get_crop_region(mask_image: PIL.Image.Image, width: int, height: int, pad=0):
        r"""
        Finds a rectangular region that contains all masked areas in an image, and expands region to match the aspect
        ratio of the original image; for example, if user drew mask in a 128x32 region, and the dimensions for
        processing are 512x512, the region will be expanded to 128x128.

        Args:
            mask_image (PIL.Image.Image): Mask image.
            width (int): Width of the image to be processed.
            height (int): Height of the image to be processed.
            pad (int, optional): Padding to be added to the crop region. Defaults to 0.

        Returns:
            tuple: (x1, y1, x2, y2) represent a rectangular region that contains all masked areas in an image and
                matches the original aspect ratio.
        """

        # Work on the mask as a single-channel array; nonzero pixels are "masked".
        mask_image = mask_image.convert("L")
        mask = np.array(mask_image)

        # 1. find a rectangular region that contains all masked areas in an image
        h, w = mask.shape
        # Count all-zero columns from the left until the first masked pixel.
        crop_left = 0
        for i in range(w):
            if not (mask[:, i] == 0).all():
                break
            crop_left += 1

        # Count all-zero columns from the right.
        crop_right = 0
        for i in reversed(range(w)):
            if not (mask[:, i] == 0).all():
                break
            crop_right += 1

        # Count all-zero rows from the top.
        crop_top = 0
        for i in range(h):
            if not (mask[i] == 0).all():
                break
            crop_top += 1

        # Count all-zero rows from the bottom.
        crop_bottom = 0
        for i in reversed(range(h)):
            if not (mask[i] == 0).all():
                break
            crop_bottom += 1

        # 2. add padding to the crop region, clamped to the image bounds
        x1, y1, x2, y2 = (
            int(max(crop_left - pad, 0)),
            int(max(crop_top - pad, 0)),
            int(min(w - crop_right + pad, w)),
            int(min(h - crop_bottom + pad, h)),
        )

        # 3. expands crop region to match the aspect ratio of the image to be processed
        ratio_crop_region = (x2 - x1) / (y2 - y1)
        ratio_processing = width / height

        if ratio_crop_region > ratio_processing:
            # Region is wider than the target ratio: grow it vertically, centered,
            # then shift back inside the image and finally clamp if still too tall.
            desired_height = (x2 - x1) / ratio_processing
            desired_height_diff = int(desired_height - (y2 - y1))
            y1 -= desired_height_diff // 2
            y2 += desired_height_diff - desired_height_diff // 2
            if y2 >= mask_image.height:
                diff = y2 - mask_image.height
                y2 -= diff
                y1 -= diff
            if y1 < 0:
                y2 -= y1
                y1 -= y1
            if y2 >= mask_image.height:
                y2 = mask_image.height
        else:
            # Region is taller than (or equal to) the target ratio: grow it
            # horizontally, centered, mirroring the vertical case above.
            desired_width = (y2 - y1) * ratio_processing
            desired_width_diff = int(desired_width - (x2 - x1))
            x1 -= desired_width_diff // 2
            x2 += desired_width_diff - desired_width_diff // 2
            if x2 >= mask_image.width:
                diff = x2 - mask_image.width
                x2 -= diff
                x1 -= diff
            if x1 < 0:
                x2 -= x1
                x1 -= x1
            if x2 >= mask_image.width:
                x2 = mask_image.width

        return x1, y1, x2, y2
|
| 382 |
+
|
| 383 |
+
    def _resize_and_fill(
        self,
        image: PIL.Image.Image,
        width: int,
        height: int,
    ) -> PIL.Image.Image:
        r"""
        Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
        the image within the dimensions, filling empty with data from image.

        Args:
            image (`PIL.Image.Image`):
                The image to resize and fill.
            width (`int`):
                The width to resize the image to.
            height (`int`):
                The height to resize the image to.

        Returns:
            `PIL.Image.Image`:
                The resized and filled image.
        """

        ratio = width / height
        src_ratio = image.width / image.height

        # Scale so the image fits inside the target canvas while keeping its
        # aspect ratio (the relatively longer axis matches the target).
        src_w = width if ratio < src_ratio else image.width * height // image.height
        src_h = height if ratio >= src_ratio else image.height * width // image.width

        resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
        res = Image.new("RGB", (width, height))
        # Paste the scaled image centered on the canvas.
        res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))

        if ratio < src_ratio:
            # Target is relatively taller: fill the top and bottom bands by
            # stretching the image's edge rows (the degenerate zero-height
            # source boxes sample the first/last row — presumably edge
            # replication; confirm against Pillow's `resize(box=...)` docs).
            fill_height = height // 2 - src_h // 2
            if fill_height > 0:
                res.paste(resized.resize((width, fill_height), box=(0, 0, width, 0)), box=(0, 0))
                res.paste(
                    resized.resize((width, fill_height), box=(0, resized.height, width, resized.height)),
                    box=(0, fill_height + src_h),
                )
        elif ratio > src_ratio:
            # Target is relatively wider: fill the left and right bands by
            # stretching the image's edge columns, mirroring the case above.
            fill_width = width // 2 - src_w // 2
            if fill_width > 0:
                res.paste(resized.resize((fill_width, height), box=(0, 0, 0, height)), box=(0, 0))
                res.paste(
                    resized.resize((fill_width, height), box=(resized.width, 0, resized.width, height)),
                    box=(fill_width + src_w, 0),
                )

        return res
|
| 434 |
+
|
| 435 |
+
def _resize_and_crop(
|
| 436 |
+
self,
|
| 437 |
+
image: PIL.Image.Image,
|
| 438 |
+
width: int,
|
| 439 |
+
height: int,
|
| 440 |
+
) -> PIL.Image.Image:
|
| 441 |
+
r"""
|
| 442 |
+
Resize the image to fit within the specified width and height, maintaining the aspect ratio, and then center
|
| 443 |
+
the image within the dimensions, cropping the excess.
|
| 444 |
+
|
| 445 |
+
Args:
|
| 446 |
+
image (`PIL.Image.Image`):
|
| 447 |
+
The image to resize and crop.
|
| 448 |
+
width (`int`):
|
| 449 |
+
The width to resize the image to.
|
| 450 |
+
height (`int`):
|
| 451 |
+
The height to resize the image to.
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
`PIL.Image.Image`:
|
| 455 |
+
The resized and cropped image.
|
| 456 |
+
"""
|
| 457 |
+
ratio = width / height
|
| 458 |
+
src_ratio = image.width / image.height
|
| 459 |
+
|
| 460 |
+
src_w = width if ratio > src_ratio else image.width * height // image.height
|
| 461 |
+
src_h = height if ratio <= src_ratio else image.height * width // image.width
|
| 462 |
+
|
| 463 |
+
resized = image.resize((src_w, src_h), resample=PIL_INTERPOLATION["lanczos"])
|
| 464 |
+
res = Image.new("RGB", (width, height))
|
| 465 |
+
res.paste(resized, box=(width // 2 - src_w // 2, height // 2 - src_h // 2))
|
| 466 |
+
return res
|
| 467 |
+
|
| 468 |
+
    def resize(
        self,
        image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
        height: int,
        width: int,
        resize_mode: str = "default",  # "default", "fill", "crop"
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Resize image.

        Args:
            image (`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`):
                The image input, can be a PIL image, numpy array or pytorch tensor.
            height (`int`):
                The height to resize to.
            width (`int`):
                The width to resize to.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode to use, can be one of `default` or `fill`. If `default`, will resize the image to fit
                within the specified width and height, and it may not maintaining the original aspect ratio. If `fill`,
                will resize the image to fit within the specified width and height, maintaining the aspect ratio, and
                then center the image within the dimensions, filling empty with data from image. If `crop`, will resize
                the image to fit within the specified width and height, maintaining the aspect ratio, and then center
                the image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
                supported for PIL image input.

        Returns:
            `PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
                The resized image.
        """
        # Non-default modes depend on PIL-only paste/crop helpers.
        if resize_mode != "default" and not isinstance(image, PIL.Image.Image):
            raise ValueError(f"Only PIL image input is supported for resize_mode {resize_mode}")
        if isinstance(image, PIL.Image.Image):
            if resize_mode == "default":
                # Plain resize to the exact target size (may distort aspect ratio),
                # using the resampling filter configured on the processor.
                image = image.resize(
                    (width, height),
                    resample=PIL_INTERPOLATION[self.config.resample],
                    reducing_gap=self.config.reducing_gap,
                )
            elif resize_mode == "fill":
                image = self._resize_and_fill(image, width, height)
            elif resize_mode == "crop":
                image = self._resize_and_crop(image, width, height)
            else:
                raise ValueError(f"resize_mode {resize_mode} is not supported")

        elif isinstance(image, torch.Tensor):
            # NOTE(review): `F.interpolate` defaults to mode='nearest', so
            # `self.config.resample` is not honored for tensor input — confirm
            # this asymmetry with PIL input is intentional.
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
        elif isinstance(image, np.ndarray):
            # Round-trip through torch to reuse the same interpolation path.
            image = self.numpy_to_pt(image)
            image = torch.nn.functional.interpolate(
                image,
                size=(height, width),
            )
            image = self.pt_to_numpy(image)

        return image
|
| 528 |
+
|
| 529 |
+
def binarize(self, image: PIL.Image.Image) -> PIL.Image.Image:
|
| 530 |
+
"""
|
| 531 |
+
Create a mask.
|
| 532 |
+
|
| 533 |
+
Args:
|
| 534 |
+
image (`PIL.Image.Image`):
|
| 535 |
+
The image input, should be a PIL image.
|
| 536 |
+
|
| 537 |
+
Returns:
|
| 538 |
+
`PIL.Image.Image`:
|
| 539 |
+
The binarized image. Values less than 0.5 are set to 0, values greater than 0.5 are set to 1.
|
| 540 |
+
"""
|
| 541 |
+
image[image < 0.5] = 0
|
| 542 |
+
image[image >= 0.5] = 1
|
| 543 |
+
|
| 544 |
+
return image
|
| 545 |
+
|
| 546 |
+
def _denormalize_conditionally(
|
| 547 |
+
self, images: torch.Tensor, do_denormalize: Optional[List[bool]] = None
|
| 548 |
+
) -> torch.Tensor:
|
| 549 |
+
r"""
|
| 550 |
+
Denormalize a batch of images based on a condition list.
|
| 551 |
+
|
| 552 |
+
Args:
|
| 553 |
+
images (`torch.Tensor`):
|
| 554 |
+
The input image tensor.
|
| 555 |
+
do_denormalize (`Optional[List[bool]`, *optional*, defaults to `None`):
|
| 556 |
+
A list of booleans indicating whether to denormalize each image in the batch. If `None`, will use the
|
| 557 |
+
value of `do_normalize` in the `VaeImageProcessor` config.
|
| 558 |
+
"""
|
| 559 |
+
if do_denormalize is None:
|
| 560 |
+
return self.denormalize(images) if self.config.do_normalize else images
|
| 561 |
+
|
| 562 |
+
return torch.stack(
|
| 563 |
+
[self.denormalize(images[i]) if do_denormalize[i] else images[i] for i in range(images.shape[0])]
|
| 564 |
+
)
|
| 565 |
+
|
| 566 |
+
def get_default_height_width(
|
| 567 |
+
self,
|
| 568 |
+
image: Union[PIL.Image.Image, np.ndarray, torch.Tensor],
|
| 569 |
+
height: Optional[int] = None,
|
| 570 |
+
width: Optional[int] = None,
|
| 571 |
+
) -> Tuple[int, int]:
|
| 572 |
+
r"""
|
| 573 |
+
Returns the height and width of the image, downscaled to the next integer multiple of `vae_scale_factor`.
|
| 574 |
+
|
| 575 |
+
Args:
|
| 576 |
+
image (`Union[PIL.Image.Image, np.ndarray, torch.Tensor]`):
|
| 577 |
+
The image input, which can be a PIL image, NumPy array, or PyTorch tensor. If it is a NumPy array, it
|
| 578 |
+
should have shape `[batch, height, width]` or `[batch, height, width, channels]`. If it is a PyTorch
|
| 579 |
+
tensor, it should have shape `[batch, channels, height, width]`.
|
| 580 |
+
height (`Optional[int]`, *optional*, defaults to `None`):
|
| 581 |
+
The height of the preprocessed image. If `None`, the height of the `image` input will be used.
|
| 582 |
+
width (`Optional[int]`, *optional*, defaults to `None`):
|
| 583 |
+
The width of the preprocessed image. If `None`, the width of the `image` input will be used.
|
| 584 |
+
|
| 585 |
+
Returns:
|
| 586 |
+
`Tuple[int, int]`:
|
| 587 |
+
A tuple containing the height and width, both resized to the nearest integer multiple of
|
| 588 |
+
`vae_scale_factor`.
|
| 589 |
+
"""
|
| 590 |
+
|
| 591 |
+
if height is None:
|
| 592 |
+
if isinstance(image, PIL.Image.Image):
|
| 593 |
+
height = image.height
|
| 594 |
+
elif isinstance(image, torch.Tensor):
|
| 595 |
+
height = image.shape[2]
|
| 596 |
+
else:
|
| 597 |
+
height = image.shape[1]
|
| 598 |
+
|
| 599 |
+
if width is None:
|
| 600 |
+
if isinstance(image, PIL.Image.Image):
|
| 601 |
+
width = image.width
|
| 602 |
+
elif isinstance(image, torch.Tensor):
|
| 603 |
+
width = image.shape[3]
|
| 604 |
+
else:
|
| 605 |
+
width = image.shape[2]
|
| 606 |
+
|
| 607 |
+
width, height = (
|
| 608 |
+
x - x % self.config.vae_scale_factor for x in (width, height)
|
| 609 |
+
) # resize to integer multiple of vae_scale_factor
|
| 610 |
+
|
| 611 |
+
return height, width
|
| 612 |
+
|
| 613 |
+
    def preprocess(
        self,
        image: PipelineImageInput,
        height: Optional[int] = None,
        width: Optional[int] = None,
        resize_mode: str = "default",  # "default", "fill", "crop"
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> torch.Tensor:
        """
        Preprocess the image input.

        Args:
            image (`PipelineImageInput`):
                The image input, accepted formats are PIL images, NumPy arrays, PyTorch tensors; Also accept list of
                supported formats.
            height (`int`, *optional*):
                The height in preprocessed image. If `None`, will use the `get_default_height_width()` to get default
                height.
            width (`int`, *optional*):
                The width in preprocessed. If `None`, will use get_default_height_width()` to get the default width.
            resize_mode (`str`, *optional*, defaults to `default`):
                The resize mode, can be one of `default` or `fill`. If `default`, will resize the image to fit within
                the specified width and height, and it may not maintaining the original aspect ratio. If `fill`, will
                resize the image to fit within the specified width and height, maintaining the aspect ratio, and then
                center the image within the dimensions, filling empty with data from image. If `crop`, will resize the
                image to fit within the specified width and height, maintaining the aspect ratio, and then center the
                image within the dimensions, cropping the excess. Note that resize_mode `fill` and `crop` are only
                supported for PIL image input.
            crops_coords (`List[Tuple[int, int, int, int]]`, *optional*, defaults to `None`):
                The crop coordinates for each image in the batch. If `None`, will not crop the image.

        Returns:
            `torch.Tensor`:
                The preprocessed image.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(image, (torch.Tensor, np.ndarray)) and image.ndim == 3:
            if isinstance(image, torch.Tensor):
                # if image is a pytorch tensor could have 2 possible shapes:
                #    1. batch x height x width: we should insert the channel dimension at position 1
                #    2. channel x height x width: we should insert batch dimension at position 0,
                #       however, since both channel and batch dimension has same size 1, it is same to insert at position 1
                #    for simplicity, we insert a dimension of size 1 at position 1 for both cases
                image = image.unsqueeze(1)
            else:
                # if it is a numpy array, it could have 2 possible shapes:
                #   1. batch x height x width: insert channel dimension on last position
                #   2. height x width x channel: insert batch dimension on first position
                if image.shape[-1] == 1:
                    image = np.expand_dims(image, axis=0)
                else:
                    image = np.expand_dims(image, axis=-1)

        # Deprecated input shape: list of 4-d arrays/tensors is collapsed to a single batched array.
        if isinstance(image, list) and isinstance(image[0], np.ndarray) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d np.ndarray is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d np.ndarray",
                FutureWarning,
            )
            image = np.concatenate(image, axis=0)
        if isinstance(image, list) and isinstance(image[0], torch.Tensor) and image[0].ndim == 4:
            warnings.warn(
                "Passing `image` as a list of 4d torch.Tensor is deprecated."
                "Please concatenate the list along the batch dimension and pass it as a single 4d torch.Tensor",
                FutureWarning,
            )
            image = torch.cat(image, axis=0)

        if not is_valid_image_imagelist(image):
            raise ValueError(
                f"Input is in incorrect format. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )
        if not isinstance(image, list):
            image = [image]

        if isinstance(image[0], PIL.Image.Image):
            if crops_coords is not None:
                image = [i.crop(crops_coords) for i in image]
            if self.config.do_resize:
                height, width = self.get_default_height_width(image[0], height, width)
                image = [self.resize(i, height, width, resize_mode=resize_mode) for i in image]
            if self.config.do_convert_rgb:
                image = [self.convert_to_rgb(i) for i in image]
            elif self.config.do_convert_grayscale:
                image = [self.convert_to_grayscale(i) for i in image]
            image = self.pil_to_numpy(image)  # to np
            image = self.numpy_to_pt(image)  # to pt

        elif isinstance(image[0], np.ndarray):
            image = np.concatenate(image, axis=0) if image[0].ndim == 4 else np.stack(image, axis=0)

            image = self.numpy_to_pt(image)

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        elif isinstance(image[0], torch.Tensor):
            image = torch.cat(image, axis=0) if image[0].ndim == 4 else torch.stack(image, axis=0)

            if self.config.do_convert_grayscale and image.ndim == 3:
                image = image.unsqueeze(1)

            channel = image.shape[1]
            # don't need any preprocess if the image is latents
            if channel == self.config.vae_latent_channels:
                return image

            height, width = self.get_default_height_width(image, height, width)
            if self.config.do_resize:
                image = self.resize(image, height, width)

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if do_normalize and image.min() < 0:
            # Values below 0 suggest the caller already normalized; skip to avoid double-normalizing.
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{image.min()},{image.max()}]",
                FutureWarning,
            )
            do_normalize = False
        if do_normalize:
            image = self.normalize(image)

        if self.config.do_binarize:
            image = self.binarize(image)

        return image
|
| 743 |
+
|
| 744 |
+
def postprocess(
|
| 745 |
+
self,
|
| 746 |
+
image: torch.Tensor,
|
| 747 |
+
output_type: str = "pil",
|
| 748 |
+
do_denormalize: Optional[List[bool]] = None,
|
| 749 |
+
) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
|
| 750 |
+
"""
|
| 751 |
+
Postprocess the image output from tensor to `output_type`.
|
| 752 |
+
|
| 753 |
+
Args:
|
| 754 |
+
image (`torch.Tensor`):
|
| 755 |
+
The image input, should be a pytorch tensor with shape `B x C x H x W`.
|
| 756 |
+
output_type (`str`, *optional*, defaults to `pil`):
|
| 757 |
+
The output type of the image, can be one of `pil`, `np`, `pt`, `latent`.
|
| 758 |
+
do_denormalize (`List[bool]`, *optional*, defaults to `None`):
|
| 759 |
+
Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
|
| 760 |
+
`VaeImageProcessor` config.
|
| 761 |
+
|
| 762 |
+
Returns:
|
| 763 |
+
`PIL.Image.Image`, `np.ndarray` or `torch.Tensor`:
|
| 764 |
+
The postprocessed image.
|
| 765 |
+
"""
|
| 766 |
+
if not isinstance(image, torch.Tensor):
|
| 767 |
+
raise ValueError(
|
| 768 |
+
f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
|
| 769 |
+
)
|
| 770 |
+
if output_type not in ["latent", "pt", "np", "pil"]:
|
| 771 |
+
deprecation_message = (
|
| 772 |
+
f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
|
| 773 |
+
"`pil`, `np`, `pt`, `latent`"
|
| 774 |
+
)
|
| 775 |
+
deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
|
| 776 |
+
output_type = "np"
|
| 777 |
+
|
| 778 |
+
if output_type == "latent":
|
| 779 |
+
return image
|
| 780 |
+
|
| 781 |
+
image = self._denormalize_conditionally(image, do_denormalize)
|
| 782 |
+
|
| 783 |
+
if output_type == "pt":
|
| 784 |
+
return image
|
| 785 |
+
|
| 786 |
+
image = self.pt_to_numpy(image)
|
| 787 |
+
|
| 788 |
+
if output_type == "np":
|
| 789 |
+
return image
|
| 790 |
+
|
| 791 |
+
if output_type == "pil":
|
| 792 |
+
return self.numpy_to_pil(image)
|
| 793 |
+
|
| 794 |
+
def apply_overlay(
|
| 795 |
+
self,
|
| 796 |
+
mask: PIL.Image.Image,
|
| 797 |
+
init_image: PIL.Image.Image,
|
| 798 |
+
image: PIL.Image.Image,
|
| 799 |
+
crop_coords: Optional[Tuple[int, int, int, int]] = None,
|
| 800 |
+
) -> PIL.Image.Image:
|
| 801 |
+
r"""
|
| 802 |
+
Applies an overlay of the mask and the inpainted image on the original image.
|
| 803 |
+
|
| 804 |
+
Args:
|
| 805 |
+
mask (`PIL.Image.Image`):
|
| 806 |
+
The mask image that highlights regions to overlay.
|
| 807 |
+
init_image (`PIL.Image.Image`):
|
| 808 |
+
The original image to which the overlay is applied.
|
| 809 |
+
image (`PIL.Image.Image`):
|
| 810 |
+
The image to overlay onto the original.
|
| 811 |
+
crop_coords (`Tuple[int, int, int, int]`, *optional*):
|
| 812 |
+
Coordinates to crop the image. If provided, the image will be cropped accordingly.
|
| 813 |
+
|
| 814 |
+
Returns:
|
| 815 |
+
`PIL.Image.Image`:
|
| 816 |
+
The final image with the overlay applied.
|
| 817 |
+
"""
|
| 818 |
+
|
| 819 |
+
width, height = init_image.width, init_image.height
|
| 820 |
+
|
| 821 |
+
init_image_masked = PIL.Image.new("RGBa", (width, height))
|
| 822 |
+
init_image_masked.paste(init_image.convert("RGBA").convert("RGBa"), mask=ImageOps.invert(mask.convert("L")))
|
| 823 |
+
|
| 824 |
+
init_image_masked = init_image_masked.convert("RGBA")
|
| 825 |
+
|
| 826 |
+
if crop_coords is not None:
|
| 827 |
+
x, y, x2, y2 = crop_coords
|
| 828 |
+
w = x2 - x
|
| 829 |
+
h = y2 - y
|
| 830 |
+
base_image = PIL.Image.new("RGBA", (width, height))
|
| 831 |
+
image = self.resize(image, height=h, width=w, resize_mode="crop")
|
| 832 |
+
base_image.paste(image, (x, y))
|
| 833 |
+
image = base_image.convert("RGB")
|
| 834 |
+
|
| 835 |
+
image = image.convert("RGBA")
|
| 836 |
+
image.alpha_composite(init_image_masked)
|
| 837 |
+
image = image.convert("RGB")
|
| 838 |
+
|
| 839 |
+
return image
|
| 840 |
+
|
| 841 |
+
|
| 842 |
+
class InpaintProcessor(ConfigMixin):
    """
    Image processor for inpainting image and mask.

    Internally delegates to two `VaeImageProcessor` instances: one configured for the image
    (normalized by default) and one for the mask (binarized and converted to grayscale by
    default, but not normalized).
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        vae_latent_channels: int = 4,
        resample: str = "lanczos",
        reducing_gap: int = None,
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
        mask_do_normalize: bool = False,
        mask_do_binarize: bool = True,
        mask_do_convert_grayscale: bool = True,
    ):
        super().__init__()

        self._image_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
        self._mask_processor = VaeImageProcessor(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            vae_latent_channels=vae_latent_channels,
            resample=resample,
            reducing_gap=reducing_gap,
            do_normalize=mask_do_normalize,
            do_binarize=mask_do_binarize,
            do_convert_grayscale=mask_do_convert_grayscale,
        )

    def preprocess(
        self,
        image: PIL.Image.Image,
        mask: PIL.Image.Image = None,
        height: int = None,
        width: int = None,
        padding_mask_crop: Optional[int] = None,
    ) -> Union[torch.Tensor, Tuple[torch.Tensor, torch.Tensor, dict]]:
        """
        Preprocess the image and mask.

        Args:
            image (`PIL.Image.Image`):
                The image to preprocess.
            mask (`PIL.Image.Image`, *optional*):
                The inpainting mask. If `None`, only the image is preprocessed.
            height (`int`, *optional*):
                Target height. If `None`, derived from the image.
            width (`int`, *optional*):
                Target width. If `None`, derived from the image.
            padding_mask_crop (`int`, *optional*):
                Padding (in pixels) used to compute a crop region around the mask. Requires `mask`.

        Returns:
            `torch.Tensor` when `mask` is `None`; otherwise a 3-tuple
            `(processed_image, processed_mask, postprocessing_kwargs)` where
            `postprocessing_kwargs` holds the arguments `postprocess()` needs to re-apply the
            overlay.
        """
        if mask is None and padding_mask_crop is not None:
            raise ValueError("mask must be provided if padding_mask_crop is provided")

        # if mask is None, same behavior as regular image processor
        if mask is None:
            return self._image_processor.preprocess(image, height=height, width=width)

        if padding_mask_crop is not None:
            crops_coords = self._image_processor.get_crop_region(mask, width, height, pad=padding_mask_crop)
            resize_mode = "fill"
        else:
            crops_coords = None
            resize_mode = "default"

        processed_image = self._image_processor.preprocess(
            image,
            height=height,
            width=width,
            crops_coords=crops_coords,
            resize_mode=resize_mode,
        )

        processed_mask = self._mask_processor.preprocess(
            mask,
            height=height,
            width=width,
            resize_mode=resize_mode,
            crops_coords=crops_coords,
        )

        # Carry the originals forward only when a crop was applied: postprocess() needs them
        # to paste the inpainted patch back onto the uncropped image.
        if crops_coords is not None:
            postprocessing_kwargs = {
                "crops_coords": crops_coords,
                "original_image": image,
                "original_mask": mask,
            }
        else:
            postprocessing_kwargs = {
                "crops_coords": None,
                "original_image": None,
                "original_mask": None,
            }

        return processed_image, processed_mask, postprocessing_kwargs

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        original_image: Optional[PIL.Image.Image] = None,
        original_mask: Optional[PIL.Image.Image] = None,
        crops_coords: Optional[Tuple[int, int, int, int]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor, List[PIL.Image.Image]]:
        """
        Postprocess the image, optionally applying the mask overlay.

        Args:
            image (`torch.Tensor`):
                The decoded image tensor, shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                Output format; must be `pil` when `crops_coords` is provided.
            original_image (`PIL.Image.Image`, *optional*):
                The uncropped source image; required when `crops_coords` is provided.
            original_mask (`PIL.Image.Image`, *optional*):
                The uncropped mask; required when `crops_coords` is provided.
            crops_coords (`Tuple[int, int, int, int]`, *optional*):
                Crop region used during preprocessing; triggers the overlay paste-back.

        Returns:
            The postprocessed image(s) in the requested `output_type`; a list of
            `PIL.Image.Image` when the overlay is applied.
        """
        image = self._image_processor.postprocess(
            image,
            output_type=output_type,
        )
        # optionally apply the mask overlay
        if crops_coords is not None and (original_image is None or original_mask is None):
            raise ValueError("original_image and original_mask must be provided if crops_coords is provided")

        elif crops_coords is not None and output_type != "pil":
            raise ValueError("output_type must be 'pil' if crops_coords is provided")

        elif crops_coords is not None:
            image = [
                self._image_processor.apply_overlay(original_mask, original_image, i, crops_coords) for i in image
            ]

        return image
|
| 971 |
+
|
| 972 |
+
|
| 973 |
+
class VaeImageProcessorLDM3D(VaeImageProcessor):
    """
    Image processor for VAE LDM3D.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
    """

    config_name = CONFIG_NAME

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
    ):
        # Parameters are captured by @register_to_config; the base class reads them from config.
        super().__init__()

    @staticmethod
    def numpy_to_pil(images: np.ndarray) -> List[PIL.Image.Image]:
        r"""
        Convert a NumPy image or a batch of images to a list of PIL images.

        Args:
            images (`np.ndarray`):
                The input NumPy array of images, which can be a single image or a batch.

        Returns:
            `List[PIL.Image.Image]`:
                A list of PIL images converted from the input NumPy array.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images = (images * 255).round().astype("uint8")
        if images.shape[-1] == 1:
            # special case for grayscale (single channel) images
            pil_images = [Image.fromarray(image.squeeze(), mode="L") for image in images]
        else:
            # LDM3D arrays carry extra depth channels; only the first three (RGB) become the image.
            pil_images = [Image.fromarray(image[:, :, :3]) for image in images]

        return pil_images

    @staticmethod
    def depth_pil_to_numpy(images: Union[List[PIL.Image.Image], PIL.Image.Image]) -> np.ndarray:
        r"""
        Convert a PIL depth image or a list of PIL depth images to NumPy arrays.

        Args:
            images (`Union[List[PIL.Image.Image], PIL.Image.Image]`):
                The input image or list of images to be converted.

        Returns:
            `np.ndarray`:
                A NumPy array of the converted images, scaled to [0,1] from 16-bit values.
        """
        if not isinstance(images, list):
            images = [images]

        # Depth is stored as 16-bit; scale by 2**16 - 1 to map into [0,1].
        images = [np.array(image).astype(np.float32) / (2**16 - 1) for image in images]
        images = np.stack(images, axis=0)
        return images

    @staticmethod
    def rgblike_to_depthmap(image: Union[np.ndarray, torch.Tensor]) -> Union[np.ndarray, torch.Tensor]:
        r"""
        Convert an RGB-like depth image to a depth map.

        Args:
            image (`Union[np.ndarray, torch.Tensor]`):
                The RGB-like depth image to convert.

        Returns:
            `Union[np.ndarray, torch.Tensor]`:
                The corresponding depth map (channel 1 is the high byte, channel 2 the low byte).
        """
        return image[:, :, 1] * 2**8 + image[:, :, 2]

    def numpy_to_depth(self, images: np.ndarray) -> List[PIL.Image.Image]:
        r"""
        Convert a NumPy depth image or a batch of images to a list of 16-bit PIL images.

        Args:
            images (`np.ndarray`):
                The input NumPy array of depth images, which can be a single image or a batch.
                Supported channel counts are 6 (RGB + RGB-like depth) and 4 (RGB + raw depth).

        Returns:
            `List[PIL.Image.Image]`:
                A list of PIL images (mode `I;16`) converted from the input NumPy depth images.

        Raises:
            `Exception`: If the channel count is neither 6 nor 4.
        """
        if images.ndim == 3:
            images = images[None, ...]
        images_depth = images[:, :, :, 3:]
        if images.shape[-1] == 6:
            # Depth is encoded RGB-like across the last 3 channels; decode back to 16-bit.
            images_depth = (images_depth * 255).round().astype("uint8")
            pil_images = [
                Image.fromarray(self.rgblike_to_depthmap(image_depth), mode="I;16") for image_depth in images_depth
            ]
        elif images.shape[-1] == 4:
            # Single raw depth channel in [0,1]; rescale to the full uint16 range.
            images_depth = (images_depth * 65535.0).astype(np.uint16)
            pil_images = [Image.fromarray(image_depth, mode="I;16") for image_depth in images_depth]
        else:
            raise Exception("Not supported")

        return pil_images

    def postprocess(
        self,
        image: torch.Tensor,
        output_type: str = "pil",
        do_denormalize: Optional[List[bool]] = None,
    ) -> Union[PIL.Image.Image, np.ndarray, torch.Tensor]:
        """
        Postprocess the image output from tensor to `output_type`.

        Args:
            image (`torch.Tensor`):
                The image input, should be a pytorch tensor with shape `B x C x H x W`.
            output_type (`str`, *optional*, defaults to `pil`):
                The output type of the image; only `np` and `pil` are supported here.
            do_denormalize (`List[bool]`, *optional*, defaults to `None`):
                Whether to denormalize the image to [0,1]. If `None`, will use the value of `do_normalize` in the
                `VaeImageProcessor` config.

        Returns:
            A 2-tuple `(rgb, depth)` in the requested format.

        Raises:
            `Exception`: For `output_type` values other than `np` or `pil` (including `latent`/`pt`).
        """
        if not isinstance(image, torch.Tensor):
            raise ValueError(
                f"Input for postprocessing is in incorrect format: {type(image)}. We only support pytorch tensor"
            )
        if output_type not in ["latent", "pt", "np", "pil"]:
            deprecation_message = (
                f"the output_type {output_type} is outdated and has been set to `np`. Please make sure to set it to one of these instead: "
                "`pil`, `np`, `pt`, `latent`"
            )
            deprecate("Unsupported output_type", "1.0.0", deprecation_message, standard_warn=False)
            output_type = "np"

        image = self._denormalize_conditionally(image, do_denormalize)

        image = self.pt_to_numpy(image)

        if output_type == "np":
            if image.shape[-1] == 6:
                image_depth = np.stack([self.rgblike_to_depthmap(im[:, :, 3:]) for im in image], axis=0)
            else:
                image_depth = image[:, :, :, 3:]
            return image[:, :, :, :3], image_depth

        if output_type == "pil":
            return self.numpy_to_pil(image), self.numpy_to_depth(image)
        else:
            raise Exception(f"This type {output_type} is not supported")

    def preprocess(
        self,
        rgb: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        depth: Union[torch.Tensor, PIL.Image.Image, np.ndarray],
        height: Optional[int] = None,
        width: Optional[int] = None,
        target_res: Optional[int] = None,
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        r"""
        Preprocess the image input. Accepted formats are PIL images or NumPy arrays (PyTorch
        tensors are not yet supported).

        Args:
            rgb (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
                The RGB input image, which can be a single image or a batch.
            depth (`Union[torch.Tensor, PIL.Image.Image, np.ndarray]`):
                The depth input image, which can be a single image or a batch.
            height (`Optional[int]`, *optional*, defaults to `None`):
                The desired height of the processed image. If `None`, defaults to the height of the input image.
            width (`Optional[int]`, *optional*, defaults to `None`):
                The desired width of the processed image. If `None`, defaults to the width of the input image.
            target_res (`Optional[int]`, *optional*, defaults to `None`):
                Target resolution for resizing the images. If specified, overrides height and width.

        Returns:
            `Tuple[torch.Tensor, torch.Tensor]`:
                A tuple containing the processed RGB and depth images as PyTorch tensors.
        """
        supported_formats = (PIL.Image.Image, np.ndarray, torch.Tensor)

        # Expand the missing dimension for 3-dimensional pytorch tensor or numpy array that represents grayscale image
        if self.config.do_convert_grayscale and isinstance(rgb, (torch.Tensor, np.ndarray)) and rgb.ndim == 3:
            raise Exception("This is not yet supported")

        if isinstance(rgb, supported_formats):
            rgb = [rgb]
            depth = [depth]
        elif not (isinstance(rgb, list) and all(isinstance(i, supported_formats) for i in rgb)):
            # Fix: join over str(x) — joining the tuple of type objects raised TypeError before
            # the intended ValueError could be constructed.
            raise ValueError(
                f"Input is in incorrect format: {[type(i) for i in rgb]}. Currently, we only support {', '.join(str(x) for x in supported_formats)}"
            )

        if isinstance(rgb[0], PIL.Image.Image):
            if self.config.do_convert_rgb:
                raise Exception("This is not yet supported")
            if self.config.do_resize or target_res:
                height, width = self.get_default_height_width(rgb[0], height, width) if not target_res else target_res
                rgb = [self.resize(i, height, width) for i in rgb]
                depth = [self.resize(i, height, width) for i in depth]
            rgb = self.pil_to_numpy(rgb)  # to np
            rgb = self.numpy_to_pt(rgb)  # to pt

            depth = self.depth_pil_to_numpy(depth)  # to np
            depth = self.numpy_to_pt(depth)  # to pt

        elif isinstance(rgb[0], np.ndarray):
            rgb = np.concatenate(rgb, axis=0) if rgb[0].ndim == 4 else np.stack(rgb, axis=0)
            rgb = self.numpy_to_pt(rgb)
            height, width = self.get_default_height_width(rgb, height, width)
            if self.config.do_resize:
                rgb = self.resize(rgb, height, width)

            # Fix: inspect `depth[0]` — `rgb` was already converted to a 4D tensor above, so
            # `rgb[0].ndim` (previously used here) no longer reflects whether the depth list
            # entries are batched.
            depth = np.concatenate(depth, axis=0) if depth[0].ndim == 4 else np.stack(depth, axis=0)
            depth = self.numpy_to_pt(depth)
            height, width = self.get_default_height_width(depth, height, width)
            if self.config.do_resize:
                depth = self.resize(depth, height, width)

        elif isinstance(rgb[0], torch.Tensor):
            # TODO: torch tensor inputs are not implemented for the LDM3D processor.
            raise Exception("This is not yet supported")

        # expected range [0,1], normalize to [-1,1]
        do_normalize = self.config.do_normalize
        if rgb.min() < 0 and do_normalize:
            warnings.warn(
                "Passing `image` as torch tensor with value range in [-1,1] is deprecated. The expected value range for image tensor is [0,1] "
                f"when passing as pytorch tensor or numpy Array. You passed `image` with value range [{rgb.min()},{rgb.max()}]",
                FutureWarning,
            )
            do_normalize = False

        if do_normalize:
            rgb = self.normalize(rgb)
            depth = self.normalize(depth)

        if self.config.do_binarize:
            rgb = self.binarize(rgb)
            depth = self.binarize(depth)

        return rgb, depth
|
| 1251 |
+
|
| 1252 |
+
|
| 1253 |
+
class IPAdapterMaskProcessor(VaeImageProcessor):
|
| 1254 |
+
"""
|
| 1255 |
+
Image processor for IP Adapter image masks.
|
| 1256 |
+
|
| 1257 |
+
Args:
|
| 1258 |
+
do_resize (`bool`, *optional*, defaults to `True`):
|
| 1259 |
+
Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`.
|
| 1260 |
+
vae_scale_factor (`int`, *optional*, defaults to `8`):
|
| 1261 |
+
VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
|
| 1262 |
+
resample (`str`, *optional*, defaults to `lanczos`):
|
| 1263 |
+
Resampling filter to use when resizing the image.
|
| 1264 |
+
do_normalize (`bool`, *optional*, defaults to `False`):
|
| 1265 |
+
Whether to normalize the image to [-1,1].
|
| 1266 |
+
do_binarize (`bool`, *optional*, defaults to `True`):
|
| 1267 |
+
Whether to binarize the image to 0/1.
|
| 1268 |
+
do_convert_grayscale (`bool`, *optional*, defaults to be `True`):
|
| 1269 |
+
Whether to convert the images to grayscale format.
|
| 1270 |
+
|
| 1271 |
+
"""
|
| 1272 |
+
|
| 1273 |
+
config_name = CONFIG_NAME
|
| 1274 |
+
|
| 1275 |
+
    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = False,
        do_binarize: bool = True,
        do_convert_grayscale: bool = True,
    ):
        # IP-Adapter masks differ from regular images: they are binarized and converted to
        # grayscale but NOT normalized to [-1,1], hence the non-default flag values above.
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )
|
| 1293 |
+
|
| 1294 |
+
@staticmethod
|
| 1295 |
+
def downsample(mask: torch.Tensor, batch_size: int, num_queries: int, value_embed_dim: int):
|
| 1296 |
+
"""
|
| 1297 |
+
Downsamples the provided mask tensor to match the expected dimensions for scaled dot-product attention. If the
|
| 1298 |
+
aspect ratio of the mask does not match the aspect ratio of the output image, a warning is issued.
|
| 1299 |
+
|
| 1300 |
+
Args:
|
| 1301 |
+
mask (`torch.Tensor`):
|
| 1302 |
+
The input mask tensor generated with `IPAdapterMaskProcessor.preprocess()`.
|
| 1303 |
+
batch_size (`int`):
|
| 1304 |
+
The batch size.
|
| 1305 |
+
num_queries (`int`):
|
| 1306 |
+
The number of queries.
|
| 1307 |
+
value_embed_dim (`int`):
|
| 1308 |
+
The dimensionality of the value embeddings.
|
| 1309 |
+
|
| 1310 |
+
Returns:
|
| 1311 |
+
`torch.Tensor`:
|
| 1312 |
+
The downsampled mask tensor.
|
| 1313 |
+
|
| 1314 |
+
"""
|
| 1315 |
+
o_h = mask.shape[1]
|
| 1316 |
+
o_w = mask.shape[2]
|
| 1317 |
+
ratio = o_w / o_h
|
| 1318 |
+
mask_h = int(math.sqrt(num_queries / ratio))
|
| 1319 |
+
mask_h = int(mask_h) + int((num_queries % int(mask_h)) != 0)
|
| 1320 |
+
mask_w = num_queries // mask_h
|
| 1321 |
+
|
| 1322 |
+
mask_downsample = F.interpolate(mask.unsqueeze(0), size=(mask_h, mask_w), mode="bicubic").squeeze(0)
|
| 1323 |
+
|
| 1324 |
+
# Repeat batch_size times
|
| 1325 |
+
if mask_downsample.shape[0] < batch_size:
|
| 1326 |
+
mask_downsample = mask_downsample.repeat(batch_size, 1, 1)
|
| 1327 |
+
|
| 1328 |
+
mask_downsample = mask_downsample.view(mask_downsample.shape[0], -1)
|
| 1329 |
+
|
| 1330 |
+
downsampled_area = mask_h * mask_w
|
| 1331 |
+
# If the output image and the mask do not have the same aspect ratio, tensor shapes will not match
|
| 1332 |
+
# Pad tensor if downsampled_mask.shape[1] is smaller than num_queries
|
| 1333 |
+
if downsampled_area < num_queries:
|
| 1334 |
+
warnings.warn(
|
| 1335 |
+
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
|
| 1336 |
+
"Please update your masks or adjust the output size for optimal performance.",
|
| 1337 |
+
UserWarning,
|
| 1338 |
+
)
|
| 1339 |
+
mask_downsample = F.pad(mask_downsample, (0, num_queries - mask_downsample.shape[1]), value=0.0)
|
| 1340 |
+
# Discard last embeddings if downsampled_mask.shape[1] is bigger than num_queries
|
| 1341 |
+
if downsampled_area > num_queries:
|
| 1342 |
+
warnings.warn(
|
| 1343 |
+
"The aspect ratio of the mask does not match the aspect ratio of the output image. "
|
| 1344 |
+
"Please update your masks or adjust the output size for optimal performance.",
|
| 1345 |
+
UserWarning,
|
| 1346 |
+
)
|
| 1347 |
+
mask_downsample = mask_downsample[:, :num_queries]
|
| 1348 |
+
|
| 1349 |
+
# Repeat last dimension to match SDPA output shape
|
| 1350 |
+
mask_downsample = mask_downsample.view(mask_downsample.shape[0], mask_downsample.shape[1], 1).repeat(
|
| 1351 |
+
1, 1, value_embed_dim
|
| 1352 |
+
)
|
| 1353 |
+
|
| 1354 |
+
return mask_downsample
|
| 1355 |
+
|
| 1356 |
+
|
| 1357 |
+
class PixArtImageProcessor(VaeImageProcessor):
    """
    Image processor for PixArt image resize and crop.

    Args:
        do_resize (`bool`, *optional*, defaults to `True`):
            Whether to downscale the image's (height, width) dimensions to multiples of `vae_scale_factor`. Can accept
            `height` and `width` arguments from [`image_processor.VaeImageProcessor.preprocess`] method.
        vae_scale_factor (`int`, *optional*, defaults to `8`):
            VAE scale factor. If `do_resize` is `True`, the image is automatically resized to multiples of this factor.
        resample (`str`, *optional*, defaults to `lanczos`):
            Resampling filter to use when resizing the image.
        do_normalize (`bool`, *optional*, defaults to `True`):
            Whether to normalize the image to [-1,1].
        do_binarize (`bool`, *optional*, defaults to `False`):
            Whether to binarize the image to 0/1.
        do_convert_grayscale (`bool`, *optional*, defaults to be `False`):
            Whether to convert the images to grayscale format.
    """

    @register_to_config
    def __init__(
        self,
        do_resize: bool = True,
        vae_scale_factor: int = 8,
        resample: str = "lanczos",
        do_normalize: bool = True,
        do_binarize: bool = False,
        do_convert_grayscale: bool = False,
    ):
        super().__init__(
            do_resize=do_resize,
            vae_scale_factor=vae_scale_factor,
            resample=resample,
            do_normalize=do_normalize,
            do_binarize=do_binarize,
            do_convert_grayscale=do_convert_grayscale,
        )

    @staticmethod
    def classify_height_width_bin(height: int, width: int, ratios: dict) -> Tuple[int, int]:
        r"""
        Returns the binned height and width based on the aspect ratio.

        Args:
            height (`int`): The height of the image.
            width (`int`): The width of the image.
            ratios (`dict`): A dictionary where keys are aspect ratios and values are tuples of (height, width).

        Returns:
            `Tuple[int, int]`: The closest binned height and width.
        """
        target_ratio = float(height / width)
        # Keys are string/number aspect ratios; pick the one closest to the input ratio.
        closest_ratio = min(ratios.keys(), key=lambda ratio: abs(float(ratio) - target_ratio))
        binned_h, binned_w = ratios[closest_ratio]
        return int(binned_h), int(binned_w)

    @staticmethod
    def resize_and_crop_tensor(samples: torch.Tensor, new_width: int, new_height: int) -> torch.Tensor:
        r"""
        Resizes and crops a tensor of images to the specified dimensions.

        Args:
            samples (`torch.Tensor`):
                A tensor of shape (N, C, H, W) where N is the batch size, C is the number of channels, H is the height,
                and W is the width.
            new_width (`int`): The desired width of the output images.
            new_height (`int`): The desired height of the output images.

        Returns:
            `torch.Tensor`: A tensor containing the resized and cropped images.
        """
        orig_height, orig_width = samples.shape[2], samples.shape[3]

        # Skip all work when the tensor already has the requested size.
        if orig_height != new_height or orig_width != new_width:
            # Scale so that both target dimensions are covered, then center-crop.
            ratio = max(new_height / orig_height, new_width / orig_width)
            resized_width = int(orig_width * ratio)
            resized_height = int(orig_height * ratio)

            samples = F.interpolate(
                samples, size=(resized_height, resized_width), mode="bilinear", align_corners=False
            )

            left = (resized_width - new_width) // 2
            top = (resized_height - new_height) // 2
            samples = samples[:, :, top : top + new_height, left : left + new_width]

        return samples
|
pythonProject/.venv/Lib/site-packages/diffusers/optimization.py
ADDED
|
@@ -0,0 +1,361 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# coding=utf-8
|
| 2 |
+
# Copyright 2025 The HuggingFace Inc. team.
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
"""PyTorch optimization for diffusion models."""
|
| 16 |
+
|
| 17 |
+
import math
|
| 18 |
+
from enum import Enum
|
| 19 |
+
from typing import Optional, Union
|
| 20 |
+
|
| 21 |
+
from torch.optim import Optimizer
|
| 22 |
+
from torch.optim.lr_scheduler import LambdaLR
|
| 23 |
+
|
| 24 |
+
from .utils import logging
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
logger = logging.get_logger(__name__)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class SchedulerType(Enum):
    """Names of the learning-rate schedules supported by `get_scheduler`."""

    LINEAR = "linear"
    COSINE = "cosine"
    COSINE_WITH_RESTARTS = "cosine_with_restarts"
    POLYNOMIAL = "polynomial"
    CONSTANT = "constant"
    CONSTANT_WITH_WARMUP = "constant_with_warmup"
    PIECEWISE_CONSTANT = "piecewise_constant"
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def get_constant_schedule(optimizer: Optimizer, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(_step: int) -> float:
        # Multiplier is always 1: the optimizer's configured lr is used unchanged.
        return 1.0

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
def get_constant_schedule_with_warmup(optimizer: Optimizer, num_warmup_steps: int, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate preceded by a warmup period during which the learning rate
    increases linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(step: int) -> float:
        # After warmup the multiplier stays pinned at 1.
        if step >= num_warmup_steps:
            return 1.0
        # Linear ramp 0 -> 1 over the warmup window (guard against a 0-step warmup).
        return float(step) / float(max(1.0, num_warmup_steps))

    return LambdaLR(optimizer, lr_lambda, last_epoch=last_epoch)
|
| 79 |
+
|
| 80 |
+
|
| 81 |
+
def get_piecewise_constant_schedule(optimizer: Optimizer, step_rules: str, last_epoch: int = -1) -> LambdaLR:
    """
    Create a schedule with a constant learning rate, using the learning rate set in optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        step_rules (`string`):
            The rules for the learning rate. ex: rule_steps="1:10,0.1:20,0.01:30,0.005" it means that the learning rate
            if multiple 1 for the first 10 steps, multiple 0.1 for the next 20 steps, multiple 0.01 for the next 30
            steps and multiple 0.005 for the other steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # Parse "multiplier:boundary" pairs; the final comma-separated entry is the
    # multiplier applied after the last boundary.
    *pair_strings, final_rule = step_rules.split(",")
    milestones = {}
    for pair in pair_strings:
        multiplier_str, boundary_str = pair.split(":")
        milestones[int(boundary_str)] = float(multiplier_str)
    last_lr_multiple = float(final_rule)

    def rule_func(steps: int) -> float:
        # Return the multiplier of the first boundary the step has not yet reached.
        for boundary in sorted(milestones):
            if steps < boundary:
                return milestones[boundary]
        return last_lr_multiple

    return LambdaLR(optimizer, rule_func, last_epoch=last_epoch)
|
| 121 |
+
|
| 122 |
+
|
| 123 |
+
def get_linear_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases linearly from the initial lr set in the optimizer to 0, after
    a warmup period during which it increases linearly from 0 to the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(step: int) -> float:
        # Warmup: linear ramp 0 -> 1.
        if step < num_warmup_steps:
            return float(step) / float(max(1, num_warmup_steps))
        # Decay: linear ramp 1 -> 0 over the remaining steps, floored at 0.
        remaining = float(num_training_steps - step)
        decay_span = float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, remaining / decay_span)

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 152 |
+
|
| 153 |
+
|
| 154 |
+
def get_cosine_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: float = 0.5, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, after a warmup period during which it increases linearly between 0 and the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`float`, *optional*, defaults to 0.5):
            The number of periods of the cosine function in a schedule (the default is to just decrease from the max
            value to 0 following a half-cosine).
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """
    # NOTE: the docstring previously documented a nonexistent `num_periods`
    # parameter; the actual keyword argument is `num_cycles`.

    def lr_lambda(current_step):
        # Warmup: linear ramp 0 -> 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # Cosine decay over the post-warmup fraction of training, floored at 0.
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * float(num_cycles) * 2.0 * progress)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
def get_cosine_with_hard_restarts_schedule_with_warmup(
    optimizer: Optimizer, num_warmup_steps: int, num_training_steps: int, num_cycles: int = 1, last_epoch: int = -1
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases following the values of the cosine function between the
    initial lr set in the optimizer to 0, with several hard restarts, after a warmup period during which it increases
    linearly between 0 and the initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        num_cycles (`int`, *optional*, defaults to 1):
            The number of hard restarts to use.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.
    """

    def lr_lambda(current_step):
        # Warmup: linear ramp 0 -> 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        progress = float(current_step - num_warmup_steps) / float(max(1, num_training_steps - num_warmup_steps))
        # Past the end of training the multiplier stays at 0.
        if progress >= 1.0:
            return 0.0
        # Each cycle restarts the cosine at its maximum (phase wraps via modulo).
        phase = (float(num_cycles) * progress) % 1.0
        return max(0.0, 0.5 * (1.0 + math.cos(math.pi * phase)))

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 221 |
+
|
| 222 |
+
|
| 223 |
+
def get_polynomial_decay_schedule_with_warmup(
    optimizer: Optimizer,
    num_warmup_steps: int,
    num_training_steps: int,
    lr_end: float = 1e-7,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Create a schedule with a learning rate that decreases as a polynomial decay from the initial lr set in the
    optimizer to end lr defined by *lr_end*, after a warmup period during which it increases linearly from 0 to the
    initial lr set in the optimizer.

    Args:
        optimizer ([`~torch.optim.Optimizer`]):
            The optimizer for which to schedule the learning rate.
        num_warmup_steps (`int`):
            The number of steps for the warmup phase.
        num_training_steps (`int`):
            The total number of training steps.
        lr_end (`float`, *optional*, defaults to 1e-7):
            The end LR.
        power (`float`, *optional*, defaults to 1.0):
            Power factor.
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.

    Note: *power* defaults to 1.0 as in the fairseq implementation, which in turn is based on the original BERT
    implementation at
    https://github.com/google-research/bert/blob/f39e881b169b9d53bea03d2d341b31707a6c052b/optimization.py#L37

    Return:
        `torch.optim.lr_scheduler.LambdaLR` with the appropriate schedule.

    """
    lr_init = optimizer.defaults["lr"]
    if not (lr_init > lr_end):
        raise ValueError(f"lr_end ({lr_end}) must be smaller than initial lr ({lr_init})")

    def lr_lambda(current_step: int):
        # Warmup: linear ramp 0 -> 1.
        if current_step < num_warmup_steps:
            return float(current_step) / float(max(1, num_warmup_steps))
        # After training ends, hold the final lr (LambdaLR multiplies by lr_init,
        # so the multiplier must be lr_end / lr_init).
        if current_step > num_training_steps:
            return lr_end / lr_init
        # Polynomial interpolation from lr_init down to lr_end.
        decay_steps = num_training_steps - num_warmup_steps
        pct_remaining = 1 - (current_step - num_warmup_steps) / decay_steps
        decayed = (lr_init - lr_end) * pct_remaining**power + lr_end
        return decayed / lr_init  # as LambdaLR multiplies by lr_init

    return LambdaLR(optimizer, lr_lambda, last_epoch)
|
| 276 |
+
|
| 277 |
+
|
| 278 |
+
# Dispatch table used by `get_scheduler`: maps each `SchedulerType` member to
# the factory function that constructs the corresponding LambdaLR schedule.
TYPE_TO_SCHEDULER_FUNCTION = {
    SchedulerType.LINEAR: get_linear_schedule_with_warmup,
    SchedulerType.COSINE: get_cosine_schedule_with_warmup,
    SchedulerType.COSINE_WITH_RESTARTS: get_cosine_with_hard_restarts_schedule_with_warmup,
    SchedulerType.POLYNOMIAL: get_polynomial_decay_schedule_with_warmup,
    SchedulerType.CONSTANT: get_constant_schedule,
    SchedulerType.CONSTANT_WITH_WARMUP: get_constant_schedule_with_warmup,
    SchedulerType.PIECEWISE_CONSTANT: get_piecewise_constant_schedule,
}
|
| 287 |
+
|
| 288 |
+
|
| 289 |
+
def get_scheduler(
    name: Union[str, SchedulerType],
    optimizer: Optimizer,
    step_rules: Optional[str] = None,
    num_warmup_steps: Optional[int] = None,
    num_training_steps: Optional[int] = None,
    num_cycles: int = 1,
    power: float = 1.0,
    last_epoch: int = -1,
) -> LambdaLR:
    """
    Unified API to get any scheduler from its name.

    Args:
        name (`str` or `SchedulerType`):
            The name of the scheduler to use.
        optimizer (`torch.optim.Optimizer`):
            The optimizer that will be used during training.
        step_rules (`str`, *optional*):
            A string representing the step rules to use. This is only used by the `PIECEWISE_CONSTANT` scheduler.
        num_warmup_steps (`int`, *optional*):
            The number of warmup steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_training_steps (`int``, *optional*):
            The number of training steps to do. This is not required by all schedulers (hence the argument being
            optional), the function will raise an error if it's unset and the scheduler type requires it.
        num_cycles (`int`, *optional*):
            The number of hard restarts used in `COSINE_WITH_RESTARTS` scheduler.
        power (`float`, *optional*, defaults to 1.0):
            Power factor. See `POLYNOMIAL` scheduler
        last_epoch (`int`, *optional*, defaults to -1):
            The index of the last epoch when resuming training.
    """
    name = SchedulerType(name)
    scheduler_factory = TYPE_TO_SCHEDULER_FUNCTION[name]

    # Schedulers that do not need any step counts.
    if name == SchedulerType.CONSTANT:
        return scheduler_factory(optimizer, last_epoch=last_epoch)
    if name == SchedulerType.PIECEWISE_CONSTANT:
        return scheduler_factory(optimizer, step_rules=step_rules, last_epoch=last_epoch)

    # All other schedulers require `num_warmup_steps`.
    if num_warmup_steps is None:
        raise ValueError(f"{name} requires `num_warmup_steps`, please provide that argument.")

    if name == SchedulerType.CONSTANT_WITH_WARMUP:
        return scheduler_factory(optimizer, num_warmup_steps=num_warmup_steps, last_epoch=last_epoch)

    # All remaining schedulers also require `num_training_steps`.
    if num_training_steps is None:
        raise ValueError(f"{name} requires `num_training_steps`, please provide that argument.")

    # Scheduler-specific extras: only COSINE_WITH_RESTARTS takes `num_cycles`
    # and only POLYNOMIAL takes `power`.
    extra_kwargs = {}
    if name == SchedulerType.COSINE_WITH_RESTARTS:
        extra_kwargs["num_cycles"] = num_cycles
    elif name == SchedulerType.POLYNOMIAL:
        extra_kwargs["power"] = power

    return scheduler_factory(
        optimizer,
        num_warmup_steps=num_warmup_steps,
        num_training_steps=num_training_steps,
        last_epoch=last_epoch,
        **extra_kwargs,
    )
|
pythonProject/.venv/Lib/site-packages/diffusers/pipelines/marigold/__pycache__/pipeline_marigold_normals.cpython-310.pyc
ADDED
|
Binary file (22.6 kB). View file
|
|
|
pythonProject/.venv/Lib/site-packages/diffusers/py.typed
ADDED
|
File without changes
|
pythonProject/.venv/Lib/site-packages/diffusers/training_utils.py
ADDED
|
@@ -0,0 +1,730 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import contextlib
|
| 2 |
+
import copy
|
| 3 |
+
import gc
|
| 4 |
+
import math
|
| 5 |
+
import random
|
| 6 |
+
import re
|
| 7 |
+
import warnings
|
| 8 |
+
from contextlib import contextmanager
|
| 9 |
+
from typing import Any, Dict, Iterable, List, Optional, Tuple, Union
|
| 10 |
+
|
| 11 |
+
import numpy as np
|
| 12 |
+
import torch
|
| 13 |
+
|
| 14 |
+
from .models import UNet2DConditionModel
|
| 15 |
+
from .pipelines import DiffusionPipeline
|
| 16 |
+
from .schedulers import SchedulerMixin
|
| 17 |
+
from .utils import (
|
| 18 |
+
convert_state_dict_to_diffusers,
|
| 19 |
+
convert_state_dict_to_peft,
|
| 20 |
+
deprecate,
|
| 21 |
+
is_peft_available,
|
| 22 |
+
is_torch_npu_available,
|
| 23 |
+
is_torchvision_available,
|
| 24 |
+
is_transformers_available,
|
| 25 |
+
)
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
if is_transformers_available():
|
| 29 |
+
import transformers
|
| 30 |
+
|
| 31 |
+
if transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
|
| 32 |
+
import deepspeed
|
| 33 |
+
|
| 34 |
+
if is_peft_available():
|
| 35 |
+
from peft import set_peft_model_state_dict
|
| 36 |
+
|
| 37 |
+
if is_torchvision_available():
|
| 38 |
+
from torchvision import transforms
|
| 39 |
+
|
| 40 |
+
if is_torch_npu_available():
|
| 41 |
+
import torch_npu # noqa: F401
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
def set_seed(seed: int):
    """
    Seed `random`, `numpy` and `torch` (plus the NPU/CUDA RNGs) so runs are reproducible.

    Args:
        seed (`int`): The seed to set.

    Returns:
        `None`
    """
    for seeder in (random.seed, np.random.seed, torch.manual_seed):
        seeder(seed)
    if is_torch_npu_available():
        torch.npu.manual_seed_all(seed)
    else:
        # Safe to call even when CUDA is not available.
        torch.cuda.manual_seed_all(seed)
def compute_snr(noise_scheduler, timesteps):
    """
    Computes the signal-to-noise ratio (SNR) for the given timesteps, as per
    https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L847-L849

    Args:
        noise_scheduler (`NoiseScheduler`):
            An object containing the noise schedule parameters, specifically `alphas_cumprod`, which is used to compute
            the SNR values.
        timesteps (`torch.Tensor`):
            A tensor of timesteps for which the SNR is computed.

    Returns:
        `torch.Tensor`: A tensor containing the computed SNR values for each timestep.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod

    def _gather_and_expand(values):
        # Index the schedule at `timesteps`, then append singleton dims until the
        # ranks match so the result can be expanded to `timesteps.shape`.
        # Adapted from https://github.com/TiankaiHang/Min-SNR-Diffusion-Training/blob/521b624bd70c67cee4bdf49225915f5945a872e3/guided_diffusion/gaussian_diffusion.py#L1026
        gathered = values.to(device=timesteps.device)[timesteps].float()
        while gathered.ndim < timesteps.ndim:
            gathered = gathered.unsqueeze(-1)
        return gathered.expand(timesteps.shape)

    alpha = _gather_and_expand(alphas_cumprod**0.5)
    sigma = _gather_and_expand((1.0 - alphas_cumprod) ** 0.5)

    # SNR = (alpha / sigma)^2
    return (alpha / sigma) ** 2
def resolve_interpolation_mode(interpolation_type: str):
    """
    Maps a string describing an interpolation function to the corresponding torchvision `InterpolationMode` enum. The
    full list of supported enums is documented at
    https://pytorch.org/vision/0.9/transforms.html#torchvision.transforms.functional.InterpolationMode.

    Args:
        interpolation_type (`str`):
            One of `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, or `lanczos`, corresponding to
            the supported interpolation modes in torchvision.

    Returns:
        `torchvision.transforms.InterpolationMode`: an `InterpolationMode` enum used by torchvision's `resize`
        transform.
    """
    if not is_torchvision_available():
        raise ImportError(
            "Please make sure to install `torchvision` to be able to use the `resolve_interpolation_mode()` function."
        )

    # String name -> enum attribute name; the enum member is resolved lazily via
    # `getattr` so unsupported names never touch `InterpolationMode`.
    supported_modes = {
        "bilinear": "BILINEAR",
        "bicubic": "BICUBIC",
        "box": "BOX",
        "nearest": "NEAREST",
        "nearest_exact": "NEAREST_EXACT",
        "hamming": "HAMMING",
        "lanczos": "LANCZOS",
    }
    if interpolation_type not in supported_modes:
        raise ValueError(
            f"The given interpolation mode {interpolation_type} is not supported. Currently supported interpolation"
            f" modes are `bilinear`, `bicubic`, `box`, `nearest`, `nearest_exact`, `hamming`, and `lanczos`."
        )

    return getattr(transforms.InterpolationMode, supported_modes[interpolation_type])
def compute_dream_and_update_latents(
    unet: UNet2DConditionModel,
    noise_scheduler: SchedulerMixin,
    timesteps: torch.Tensor,
    noise: torch.Tensor,
    noisy_latents: torch.Tensor,
    target: torch.Tensor,
    encoder_hidden_states: torch.Tensor,
    dream_detail_preservation: float = 1.0,
) -> Tuple[Optional[torch.Tensor], Optional[torch.Tensor]]:
    """
    Implements "DREAM (Diffusion Rectification and Estimation-Adaptive Models)" from
    https://huggingface.co/papers/2312.00210. DREAM helps align training with sampling to help training be more
    efficient and accurate at the cost of an extra forward step without gradients.

    Args:
        `unet`: The state unet to use to make a prediction.
        `noise_scheduler`: The noise scheduler used to add noise for the given timestep.
        `timesteps`: The timesteps for the noise_scheduler to user.
        `noise`: A tensor of noise in the shape of noisy_latents.
        `noisy_latents`: Previously noise latents from the training loop.
        `target`: The ground-truth tensor to predict after eps is removed.
        `encoder_hidden_states`: Text embeddings from the text model.
        `dream_detail_preservation`: A float value that indicates detail preservation level.
          See reference.

    Returns:
        `tuple[torch.Tensor, torch.Tensor]`: Adjusted noisy_latents and target.
    """
    alphas_cumprod = noise_scheduler.alphas_cumprod.to(timesteps.device)[timesteps, None, None, None]
    sqrt_one_minus_alphas_cumprod = (1.0 - alphas_cumprod) ** 0.5

    # The paper uses lambda = sqrt(1 - alpha) ** p, with p = 1 in their experiments.
    dream_lambda = sqrt_one_minus_alphas_cumprod**dream_detail_preservation

    # One extra gradient-free forward pass to estimate the model's current prediction.
    with torch.no_grad():
        pred = unet(noisy_latents, timesteps, encoder_hidden_states).sample

    prediction_type = noise_scheduler.config.prediction_type
    if prediction_type == "v_prediction":
        raise NotImplementedError("DREAM has not been implemented for v-prediction")
    if prediction_type != "epsilon":
        raise ValueError(f"Unknown prediction type {prediction_type}")

    # Rectify both latents and target by the (lambda-scaled) residual between the
    # true noise and the model's predicted noise.
    delta_noise = (noise - pred).detach()
    delta_noise.mul_(dream_lambda)
    adjusted_latents = noisy_latents.add(sqrt_one_minus_alphas_cumprod * delta_noise)
    adjusted_target = target.add(delta_noise)

    return adjusted_latents, adjusted_target
def unet_lora_state_dict(unet: UNet2DConditionModel) -> Dict[str, torch.Tensor]:
    r"""
    Collects the LoRA weights attached to a UNet.

    Returns:
        A state dict containing just the LoRA parameters.
    """
    state_dict = {}

    for module_name, module in unet.named_modules():
        # Only modules patched for LoRA expose `set_lora_layer`.
        if not hasattr(module, "set_lora_layer"):
            continue
        lora_layer = getattr(module, "lora_layer")
        if lora_layer is None:
            continue
        # The matrix name can either be "down" or "up".
        for matrix_name, lora_param in lora_layer.state_dict().items():
            state_dict[f"{module_name}.lora.{matrix_name}"] = lora_param

    return state_dict
def cast_training_params(model: Union[torch.nn.Module, List[torch.nn.Module]], dtype=torch.float32):
    """
    Casts the trainable parameters of one or several models to the given data type.

    Args:
        model: The PyTorch model (or list of models) whose parameters will be cast.
        dtype: The data type to which the trainable parameters will be cast.
    """
    models = model if isinstance(model, list) else [model]
    for current_model in models:
        for param in current_model.parameters():
            # Only upcast trainable parameters into fp32; frozen weights keep their dtype.
            if param.requires_grad:
                param.data = param.to(dtype)
def _set_state_dict_into_text_encoder(
    lora_state_dict: Dict[str, torch.Tensor], prefix: str, text_encoder: torch.nn.Module
):
    """
    Sets the `lora_state_dict` into `text_encoder` coming from `transformers`.

    Args:
        lora_state_dict: The state dictionary to be set.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        prefix: String identifier to retrieve the portion of the state dict that belongs to `text_encoder`.
        text_encoder: Where the `lora_state_dict` is to be set.
    """

    # Keep only the entries that belong to the text encoder and strip the prefix.
    filtered = {key.replace(prefix, ""): value for key, value in lora_state_dict.items() if key.startswith(prefix)}
    # Convert diffusers-style keys into the layout expected by PEFT before loading.
    peft_state_dict = convert_state_dict_to_peft(convert_state_dict_to_diffusers(filtered))
    set_peft_model_state_dict(text_encoder, peft_state_dict, adapter_name="default")
def _collate_lora_metadata(modules_to_save: Dict[str, torch.nn.Module]) -> Dict[str, Any]:
|
| 255 |
+
metadatas = {}
|
| 256 |
+
for module_name, module in modules_to_save.items():
|
| 257 |
+
if module is not None:
|
| 258 |
+
metadatas[f"{module_name}_lora_adapter_metadata"] = module.peft_config["default"].to_dict()
|
| 259 |
+
return metadatas
|
| 260 |
+
|
| 261 |
+
|
| 262 |
+
def compute_density_for_timestep_sampling(
    weighting_scheme: str,
    batch_size: int,
    logit_mean: Optional[float] = None,
    logit_std: Optional[float] = None,
    mode_scale: Optional[float] = None,
    device: Union[torch.device, str] = "cpu",
    generator: Optional[torch.Generator] = None,
):
    """
    Compute the density for sampling the timesteps when doing SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.

    Args:
        weighting_scheme (`str`):
            `"logit_normal"`, `"mode"`, or any other value for plain uniform sampling.
        batch_size (`int`): Number of samples to draw.
        logit_mean (`float`, *optional*): Mean of the underlying Gaussian, used by `"logit_normal"`.
        logit_std (`float`, *optional*): Std of the underlying Gaussian, used by `"logit_normal"`.
        mode_scale (`float`, *optional*): Strength of the mode-centered remapping, used by `"mode"`.
        device: Device on which to draw the samples.
        generator (`torch.Generator`, *optional*): RNG for reproducible sampling.

    Returns:
        `torch.Tensor`: A tensor of `batch_size` sampling densities.
    """
    if weighting_scheme == "logit_normal":
        # Sigmoid of a Gaussian sample => logit-normal distribution on (0, 1).
        u = torch.normal(mean=logit_mean, std=logit_std, size=(batch_size,), device=device, generator=generator)
        u = torch.nn.functional.sigmoid(u)
    elif weighting_scheme == "mode":
        # Uniform samples remapped to emphasize the middle of the schedule.
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
        u = 1 - u - mode_scale * (torch.cos(math.pi * u / 2) ** 2 - 1 + u)
    else:
        u = torch.rand(size=(batch_size,), device=device, generator=generator)
    return u
def compute_loss_weighting_for_sd3(weighting_scheme: str, sigmas=None):
    """
    Computes loss weighting scheme for SD3 training.

    Courtesy: This was contributed by Rafie Walker in https://github.com/huggingface/diffusers/pull/8528.

    SD3 paper reference: https://huggingface.co/papers/2403.03206v1.
    """
    if weighting_scheme == "sigma_sqrt":
        # w = 1 / sigma^2
        return (sigmas**-2.0).float()
    if weighting_scheme == "cosmap":
        denominator = 1 - 2 * sigmas + 2 * sigmas**2
        return 2 / (math.pi * denominator)
    # Any other scheme falls back to uniform weighting.
    return torch.ones_like(sigmas)
def free_memory():
    """
    Runs garbage collection. Then clears the cache of the available accelerator.
    """
    gc.collect()

    # (availability probe, cache-clearing action) pairs, checked in priority order;
    # only the first available backend has its cache cleared.
    backends = (
        (torch.cuda.is_available, lambda: torch.cuda.empty_cache()),
        (torch.backends.mps.is_available, lambda: torch.mps.empty_cache()),
        (is_torch_npu_available, lambda: torch_npu.npu.empty_cache()),
        (lambda: hasattr(torch, "xpu") and torch.xpu.is_available(), lambda: torch.xpu.empty_cache()),
    )
    for is_available, empty_cache in backends:
        if is_available():
            empty_cache()
            break
@contextmanager
def offload_models(
    *modules: Union[torch.nn.Module, DiffusionPipeline], device: Union[str, torch.device], offload: bool = True
):
    """
    Context manager that, if offload=True, moves each module to `device` on enter, then moves it back to its original
    device on exit.

    Args:
        device (`str` or `torch.Device`): Device to move the `modules` to.
        offload (`bool`): Flag to enable offloading.
    """
    original_devices = None
    if offload:
        if any(isinstance(m, DiffusionPipeline) for m in modules):
            # A pipeline tracks a single device itself; only one pipeline may be passed.
            assert len(modules) == 1
            original_devices = [modules[0].device]
        else:
            # Plain modules: remember where each one currently lives.
            original_devices = [next(m.parameters()).device for m in modules]
        # Move everything to the target device.
        for module in modules:
            module.to(device)

    try:
        yield
    finally:
        if offload:
            # Restore every module to the device it came from.
            for module, original_device in zip(modules, original_devices):
                module.to(original_device)
def parse_buckets_string(buckets_str):
    """Parses a string defining buckets into a list of (height, width) tuples."""
    if not buckets_str:
        raise ValueError("Bucket string cannot be empty.")

    pair_pattern = re.compile(r"^\s*(\d+)\s*,\s*(\d+)\s*$")

    def _parse_pair(pair_str):
        # Parse one "height,width" entry; any ValueError is re-raised with the offending pair.
        match = pair_pattern.match(pair_str)
        if not match:
            raise ValueError(f"Invalid bucket format: '{pair_str}'. Expected 'height,width'.")
        try:
            height = int(match.group(1))
            width = int(match.group(2))
            if height <= 0 or width <= 0:
                raise ValueError("Bucket dimensions must be positive integers.")
            if height % 8 != 0 or width % 8 != 0:
                warnings.warn(f"Bucket dimension ({height},{width}) not divisible by 8. This might cause issues.")
            return (height, width)
        except ValueError as e:
            raise ValueError(f"Invalid integer in bucket pair '{pair_str}': {e}") from e

    parsed_buckets = [_parse_pair(pair_str) for pair_str in buckets_str.strip().split(";")]

    if not parsed_buckets:
        raise ValueError("No valid buckets found in the provided string.")

    return parsed_buckets
def find_nearest_bucket(h, w, bucket_options):
    """Finds the closest bucket to the given height and width."""
    best_bucket_idx = None
    min_metric = float("inf")
    for idx, (bucket_h, bucket_w) in enumerate(bucket_options):
        # Compare aspect ratios via cross-multiplication, avoiding division.
        aspect_metric = abs(h * bucket_w - w * bucket_h)
        # `<=` keeps the LAST of equally-good buckets (preserves original tie-breaking).
        if aspect_metric <= min_metric:
            min_metric = aspect_metric
            best_bucket_idx = idx
    return best_bucket_idx
# Adapted from torch-ema https://github.com/fadel/pytorch_ema/blob/master/torch_ema/ema.py#L14
class EMAModel:
    """
    Exponential Moving Average of models weights
    """

    def __init__(
        self,
        parameters: Iterable[torch.nn.Parameter],
        decay: float = 0.9999,
        min_decay: float = 0.0,
        update_after_step: int = 0,
        use_ema_warmup: bool = False,
        inv_gamma: Union[float, int] = 1.0,
        power: Union[float, int] = 2 / 3,
        foreach: bool = False,
        model_cls: Optional[Any] = None,
        model_config: Optional[Dict[str, Any]] = None,
        **kwargs,
    ):
        """
        Args:
            parameters (Iterable[torch.nn.Parameter]): The parameters to track.
            decay (float): The decay factor for the exponential moving average.
            min_decay (float): The minimum decay factor for the exponential moving average.
            update_after_step (int): The number of steps to wait before starting to update the EMA weights.
            use_ema_warmup (bool): Whether to use EMA warmup.
            inv_gamma (float):
                Inverse multiplicative factor of EMA warmup. Default: 1. Only used if `use_ema_warmup` is True.
            power (float): Exponential factor of EMA warmup. Default: 2/3. Only used if `use_ema_warmup` is True.
            foreach (bool): Use torch._foreach functions for updating shadow parameters. Should be faster.
            model_cls (type, *optional*): Model class used by the `from_pretrained`/`save_pretrained` round-trip.
            model_config (Dict[str, Any], *optional*): Config used to re-instantiate the model in `save_pretrained`.
            device (Optional[Union[str, torch.device]]): The device to store the EMA weights on. If None, the EMA
                weights will be stored on CPU.

        @crowsonkb's notes on EMA Warmup:
            If gamma=1 and power=1, implements a simple average. gamma=1, power=2/3 are good values for models you plan
            to train for a million or more steps (reaches decay factor 0.999 at 31.6K steps, 0.9999 at 1M steps),
            gamma=1, power=3/4 for models you plan to train for less (reaches decay factor 0.999 at 10K steps, 0.9999
            at 215.4k steps).
        """

        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

            # set use_ema_warmup to True if a torch.nn.Module is passed for backwards compatibility
            use_ema_warmup = True

        # Deprecated kwarg aliases kept for backwards compatibility.
        if kwargs.get("max_value", None) is not None:
            deprecation_message = "The `max_value` argument is deprecated. Please use `decay` instead."
            deprecate("max_value", "1.0.0", deprecation_message, standard_warn=False)
            decay = kwargs["max_value"]

        if kwargs.get("min_value", None) is not None:
            deprecation_message = "The `min_value` argument is deprecated. Please use `min_decay` instead."
            deprecate("min_value", "1.0.0", deprecation_message, standard_warn=False)
            min_decay = kwargs["min_value"]

        parameters = list(parameters)
        # Shadow copies of the tracked parameters; these hold the EMA weights.
        self.shadow_params = [p.clone().detach() for p in parameters]

        if kwargs.get("device", None) is not None:
            deprecation_message = "The `device` argument is deprecated. Please use `to` instead."
            deprecate("device", "1.0.0", deprecation_message, standard_warn=False)
            self.to(device=kwargs["device"])

        # Scratch copy filled by `store()` and consumed by `restore()`.
        self.temp_stored_params = None

        self.decay = decay
        self.min_decay = min_decay
        self.update_after_step = update_after_step
        self.use_ema_warmup = use_ema_warmup
        self.inv_gamma = inv_gamma
        self.power = power
        self.optimization_step = 0
        self.cur_decay_value = None  # set in `step()`
        self.foreach = foreach

        self.model_cls = model_cls
        self.model_config = model_config

    @classmethod
    def from_pretrained(cls, path, model_cls, foreach=False) -> "EMAModel":
        """Instantiate an `EMAModel` from a checkpoint written by `save_pretrained`."""
        # EMA hyperparameters were stashed in the model config as unused kwargs.
        _, ema_kwargs = model_cls.from_config(path, return_unused_kwargs=True)
        model = model_cls.from_pretrained(path)

        ema_model = cls(model.parameters(), model_cls=model_cls, model_config=model.config, foreach=foreach)

        ema_model.load_state_dict(ema_kwargs)
        return ema_model

    def save_pretrained(self, path):
        """Serialize the EMA weights (plus EMA hyperparameters) as a regular model checkpoint."""
        if self.model_cls is None:
            raise ValueError("`save_pretrained` can only be used if `model_cls` was defined at __init__.")

        if self.model_config is None:
            raise ValueError("`save_pretrained` can only be used if `model_config` was defined at __init__.")

        model = self.model_cls.from_config(self.model_config)
        state_dict = self.state_dict()
        # Shadow params are written as the model's weights, not as config entries.
        state_dict.pop("shadow_params", None)

        model.register_to_config(**state_dict)
        self.copy_to(model.parameters())
        model.save_pretrained(path)

    def get_decay(self, optimization_step: int) -> float:
        """
        Compute the decay factor for the exponential moving average.
        """
        step = max(0, optimization_step - self.update_after_step - 1)

        # No EMA averaging before `update_after_step` steps have elapsed.
        if step <= 0:
            return 0.0

        if self.use_ema_warmup:
            cur_decay_value = 1 - (1 + step / self.inv_gamma) ** -self.power
        else:
            cur_decay_value = (1 + step) / (10 + step)

        cur_decay_value = min(cur_decay_value, self.decay)
        # make sure decay is not smaller than min_decay
        cur_decay_value = max(cur_decay_value, self.min_decay)
        return cur_decay_value

    @torch.no_grad()
    def step(self, parameters: Iterable[torch.nn.Parameter]):
        """Update the shadow (EMA) weights one step towards `parameters`."""
        if isinstance(parameters, torch.nn.Module):
            deprecation_message = (
                "Passing a `torch.nn.Module` to `ExponentialMovingAverage.step` is deprecated. "
                "Please pass the parameters of the module instead."
            )
            deprecate(
                "passing a `torch.nn.Module` to `ExponentialMovingAverage.step`",
                "1.0.0",
                deprecation_message,
                standard_warn=False,
            )
            parameters = parameters.parameters()

        parameters = list(parameters)

        self.optimization_step += 1

        # Compute the decay factor for the exponential moving average.
        decay = self.get_decay(self.optimization_step)
        self.cur_decay_value = decay
        one_minus_decay = 1 - decay

        # Replaced by a deepspeed gather below when ZeRO-3 shards the parameters.
        context_manager = contextlib.nullcontext()

        if self.foreach:
            if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                context_manager = deepspeed.zero.GatheredParameters(parameters, modifier_rank=None)

            with context_manager:
                params_grad = [param for param in parameters if param.requires_grad]
                s_params_grad = [
                    s_param for s_param, param in zip(self.shadow_params, parameters) if param.requires_grad
                ]

                # Non-trainable params are mirrored verbatim instead of averaged.
                if len(params_grad) < len(parameters):
                    torch._foreach_copy_(
                        [s_param for s_param, param in zip(self.shadow_params, parameters) if not param.requires_grad],
                        [param for param in parameters if not param.requires_grad],
                        non_blocking=True,
                    )

                # s = s - (1 - decay) * (s - p), batched over all trainable params.
                torch._foreach_sub_(
                    s_params_grad, torch._foreach_sub(s_params_grad, params_grad), alpha=one_minus_decay
                )

        else:
            for s_param, param in zip(self.shadow_params, parameters):
                if is_transformers_available() and transformers.integrations.deepspeed.is_deepspeed_zero3_enabled():
                    context_manager = deepspeed.zero.GatheredParameters(param, modifier_rank=None)

                with context_manager:
                    if param.requires_grad:
                        s_param.sub_(one_minus_decay * (s_param - param))
                    else:
                        s_param.copy_(param)

    def copy_to(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        """
        Copy current averaged parameters into given collection of parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored moving averages. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """
        parameters = list(parameters)
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters],
                [s_param.to(param.device).data for s_param, param in zip(self.shadow_params, parameters)],
            )
        else:
            for s_param, param in zip(self.shadow_params, parameters):
                param.data.copy_(s_param.to(param.device).data)

    def pin_memory(self) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to pinned memory. Useful for non-blocking transfers for
        offloading EMA params to the host.
        """

        self.shadow_params = [p.pin_memory() for p in self.shadow_params]

    def to(self, device=None, dtype=None, non_blocking=False) -> None:
        r"""
        Move internal buffers of the ExponentialMovingAverage to `device`.

        Args:
            device: like `device` argument to `torch.Tensor.to`
        """
        # .to() on the tensors handles None correctly
        self.shadow_params = [
            p.to(device=device, dtype=dtype, non_blocking=non_blocking)
            if p.is_floating_point()
            else p.to(device=device, non_blocking=non_blocking)
            for p in self.shadow_params
        ]

    def state_dict(self) -> dict:
        r"""
        Returns the state of the ExponentialMovingAverage as a dict. This method is used by accelerate during
        checkpointing to save the ema state dict.
        """
        # Following PyTorch conventions, references to tensors are returned:
        # "returns a reference to the state and not its copy!" -
        # https://pytorch.org/tutorials/beginner/saving_loading_models.html#what-is-a-state-dict
        return {
            "decay": self.decay,
            "min_decay": self.min_decay,
            "optimization_step": self.optimization_step,
            "update_after_step": self.update_after_step,
            "use_ema_warmup": self.use_ema_warmup,
            "inv_gamma": self.inv_gamma,
            "power": self.power,
            "shadow_params": self.shadow_params,
        }

    def store(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Saves the current parameters for restoring later.

        Args:
            parameters: Iterable of `torch.nn.Parameter`. The parameters to be temporarily stored.
        """
        self.temp_stored_params = [param.detach().cpu().clone() for param in parameters]

    def restore(self, parameters: Iterable[torch.nn.Parameter]) -> None:
        r"""
        Restore the parameters stored with the `store` method. Useful to validate the model with EMA parameters
        without affecting the original optimization process. Store the parameters before the `copy_to()` method. After
        validation (or model saving), use this to restore the former parameters.

        Args:
            parameters: Iterable of `torch.nn.Parameter`; the parameters to be
                updated with the stored parameters. If `None`, the parameters with which this
                `ExponentialMovingAverage` was initialized will be used.
        """

        if self.temp_stored_params is None:
            raise RuntimeError("This ExponentialMovingAverage has no `store()`ed weights to `restore()`")
        if self.foreach:
            torch._foreach_copy_(
                [param.data for param in parameters], [c_param.data for c_param in self.temp_stored_params]
            )
        else:
            for c_param, param in zip(self.temp_stored_params, parameters):
                param.data.copy_(c_param.data)

        # Better memory-wise.
        self.temp_stored_params = None

    def load_state_dict(self, state_dict: dict) -> None:
        r"""
        Loads the ExponentialMovingAverage state. This method is used by accelerate during checkpointing to load the
        ema state dict.

        Args:
            state_dict (dict): EMA state. Should be an object returned
                from a call to :meth:`state_dict`.
        """
        # deepcopy, to be consistent with module API
        state_dict = copy.deepcopy(state_dict)

        self.decay = state_dict.get("decay", self.decay)
        if self.decay < 0.0 or self.decay > 1.0:
            raise ValueError("Decay must be between 0 and 1")

        self.min_decay = state_dict.get("min_decay", self.min_decay)
        if not isinstance(self.min_decay, float):
            raise ValueError("Invalid min_decay")

        self.optimization_step = state_dict.get("optimization_step", self.optimization_step)
        if not isinstance(self.optimization_step, int):
            raise ValueError("Invalid optimization_step")

        self.update_after_step = state_dict.get("update_after_step", self.update_after_step)
        if not isinstance(self.update_after_step, int):
            raise ValueError("Invalid update_after_step")

        self.use_ema_warmup = state_dict.get("use_ema_warmup", self.use_ema_warmup)
        if not isinstance(self.use_ema_warmup, bool):
            raise ValueError("Invalid use_ema_warmup")

        self.inv_gamma = state_dict.get("inv_gamma", self.inv_gamma)
        if not isinstance(self.inv_gamma, (float, int)):
            raise ValueError("Invalid inv_gamma")

        self.power = state_dict.get("power", self.power)
        if not isinstance(self.power, (float, int)):
            raise ValueError("Invalid power")

        shadow_params = state_dict.get("shadow_params", None)
        if shadow_params is not None:
            self.shadow_params = shadow_params
            if not isinstance(self.shadow_params, list):
                raise ValueError("shadow_params must be a list")
            if not all(isinstance(p, torch.Tensor) for p in self.shadow_params):
                raise ValueError("shadow_params must all be Tensors")
|