Upload folder using huggingface_hub
Browse filesThis view is limited to 50 files because it contains too many changes. See raw diff
- .gitattributes +108 -0
- FastGen/.claude/settings.local.json +19 -0
- FastGen/.gitignore +180 -0
- FastGen/CLAUDE.md +7 -0
- FastGen/CONTRIBUTING.md +93 -0
- FastGen/Dockerfile +77 -0
- FastGen/LICENSE +202 -0
- FastGen/Makefile +29 -0
- FastGen/README.md +182 -0
- FastGen/SF.md +146 -0
- FastGen/assets/teaser.png +3 -0
- FastGen/fastgen/__init__.py +2 -0
- FastGen/fastgen/callbacks/README.md +59 -0
- FastGen/fastgen/callbacks/__init__.py +2 -0
- FastGen/fastgen/callbacks/callback.py +183 -0
- FastGen/fastgen/callbacks/ct_schedule.py +83 -0
- FastGen/fastgen/callbacks/ema.py +155 -0
- FastGen/fastgen/callbacks/forced_weight_norm.py +28 -0
- FastGen/fastgen/callbacks/gpu_mem_profiler.py +134 -0
- FastGen/fastgen/callbacks/gpu_stats.py +92 -0
- FastGen/fastgen/callbacks/grad_clip.py +222 -0
- FastGen/fastgen/callbacks/param_count.py +116 -0
- FastGen/fastgen/callbacks/train_profiler.py +138 -0
- FastGen/fastgen/callbacks/wandb.py +404 -0
- FastGen/fastgen/configs/README.md +108 -0
- FastGen/fastgen/configs/__init__.py +2 -0
- FastGen/fastgen/configs/callbacks.py +63 -0
- FastGen/fastgen/configs/config.py +254 -0
- FastGen/fastgen/configs/config_utils.py +317 -0
- FastGen/fastgen/configs/data.py +123 -0
- FastGen/fastgen/configs/data_dummy.py +67 -0
- FastGen/fastgen/configs/discriminator.py +106 -0
- FastGen/fastgen/configs/experiments/CogVideoX/config_dmd2.py +48 -0
- FastGen/fastgen/configs/experiments/CogVideoX/config_kd.py +30 -0
- FastGen/fastgen/configs/experiments/CogVideoX/config_sft.py +35 -0
- FastGen/fastgen/configs/experiments/CogVideoX/config_sft_5b.py +40 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2.py +69 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_14b.py +37 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_v2w.py +18 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft.py +48 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_14b.py +19 -0
- FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_v2w.py +18 -0
- FastGen/fastgen/configs/experiments/DiT/config_mf_b.py +70 -0
- FastGen/fastgen/configs/experiments/DiT/config_mf_xl.py +26 -0
- FastGen/fastgen/configs/experiments/DiT/config_sft_dit_xl.py +68 -0
- FastGen/fastgen/configs/experiments/DiT/config_sft_sit_xl.py +75 -0
- FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10.py +28 -0
- FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10_fast.py +25 -0
- FastGen/fastgen/configs/experiments/EDM/config_cm_in64.py +65 -0
- FastGen/fastgen/configs/experiments/EDM/config_dmd2_cifar10.py +24 -0
.gitattributes
CHANGED
|
@@ -33,3 +33,111 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 33 |
*.zip filter=lfs diff=lfs merge=lfs -text
|
| 34 |
*.zst filter=lfs diff=lfs merge=lfs -text
|
| 35 |
*tfevents* filter=lfs diff=lfs merge=lfs -text
|
| 36 |
+
FastGen/assets/teaser.png filter=lfs diff=lfs merge=lfs -text
|
| 37 |
+
FastGen/scripts/inference/examples/00_child_swings_rusty_swing_set.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 38 |
+
FastGen/scripts/inference/examples/00_child_swings_rusty_swing_set.png filter=lfs diff=lfs merge=lfs -text
|
| 39 |
+
FastGen/scripts/inference/examples/01_impressionist_rubber_duck_sunset.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 40 |
+
FastGen/scripts/inference/examples/01_impressionist_rubber_duck_sunset.png filter=lfs diff=lfs merge=lfs -text
|
| 41 |
+
FastGen/scripts/inference/examples/02_two_cyclists_stop_sign.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 42 |
+
FastGen/scripts/inference/examples/02_two_cyclists_stop_sign.png filter=lfs diff=lfs merge=lfs -text
|
| 43 |
+
FastGen/scripts/inference/examples/03_astronaut_cow_beach_van_gogh.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 44 |
+
FastGen/scripts/inference/examples/03_astronaut_cow_beach_van_gogh.png filter=lfs diff=lfs merge=lfs -text
|
| 45 |
+
FastGen/scripts/inference/examples/04_bird_flying_church_tower_clock.mp4 filter=lfs diff=lfs merge=lfs -text
|
| 46 |
+
FastGen/scripts/inference/examples/04_bird_flying_church_tower_clock.png filter=lfs diff=lfs merge=lfs -text
|
| 47 |
+
checkpoints/Self-Forcing/vidprom_filtered_extended.txt filter=lfs diff=lfs merge=lfs -text
|
| 48 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/comp_effic.png filter=lfs diff=lfs merge=lfs -text
|
| 49 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/data_for_diff_stage.jpg filter=lfs diff=lfs merge=lfs -text
|
| 50 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/i2v_res.png filter=lfs diff=lfs merge=lfs -text
|
| 51 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/t2v_res.jpg filter=lfs diff=lfs merge=lfs -text
|
| 52 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/vben_1.3b_vs_sota.png filter=lfs diff=lfs merge=lfs -text
|
| 53 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/vben_vs_sota.png filter=lfs diff=lfs merge=lfs -text
|
| 54 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/video_dit_arch.jpg filter=lfs diff=lfs merge=lfs -text
|
| 55 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/video_vae_res.jpg filter=lfs diff=lfs merge=lfs -text
|
| 56 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
|
| 57 |
+
hf_models/Wan2.1-T2V-1.3B-Diffusers/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
|
| 58 |
+
pip_wheels/accelerate-1.12.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 59 |
+
pip_wheels/anyio-4.12.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 60 |
+
pip_wheels/av-14.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 61 |
+
pip_wheels/babel-2.18.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 62 |
+
pip_wheels/beautifulsoup4-4.14.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 63 |
+
pip_wheels/bleach-6.3.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 64 |
+
pip_wheels/bokeh-3.8.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 65 |
+
pip_wheels/boto3-1.42.39-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 66 |
+
pip_wheels/botocore-1.42.39-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 67 |
+
pip_wheels/certifi-2026.1.4-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 68 |
+
pip_wheels/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 69 |
+
pip_wheels/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 70 |
+
pip_wheels/click-8.3.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 71 |
+
pip_wheels/contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 72 |
+
pip_wheels/debugpy-1.8.20-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 73 |
+
pip_wheels/diffusers-0.35.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 74 |
+
pip_wheels/fsspec-2026.1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 75 |
+
pip_wheels/gitpython-3.1.46-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 76 |
+
pip_wheels/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 77 |
+
pip_wheels/huggingface_hub-0.36.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 78 |
+
pip_wheels/imageio-2.37.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 79 |
+
pip_wheels/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 80 |
+
pip_wheels/ipykernel-7.1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 81 |
+
pip_wheels/ipython-9.10.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 82 |
+
pip_wheels/ipywidgets-8.1.8-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 83 |
+
pip_wheels/jedi-0.19.2-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 84 |
+
pip_wheels/jinja2-3.1.6-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 85 |
+
pip_wheels/jupyter_client-8.8.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 86 |
+
pip_wheels/jupyter_server-2.17.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 87 |
+
pip_wheels/jupyterlab-4.5.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 88 |
+
pip_wheels/jupyterlab_widgets-3.0.16-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 89 |
+
pip_wheels/lark-1.3.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 90 |
+
pip_wheels/moviepy-2.2.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 91 |
+
pip_wheels/mpmath-1.3.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 92 |
+
pip_wheels/narwhals-2.16.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 93 |
+
pip_wheels/nbconvert-7.17.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 94 |
+
pip_wheels/networkx-3.6.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 95 |
+
pip_wheels/notebook-7.5.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 96 |
+
pip_wheels/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 97 |
+
pip_wheels/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 98 |
+
pip_wheels/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 99 |
+
pip_wheels/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 100 |
+
pip_wheels/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 101 |
+
pip_wheels/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 102 |
+
pip_wheels/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 103 |
+
pip_wheels/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 104 |
+
pip_wheels/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 105 |
+
pip_wheels/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 106 |
+
pip_wheels/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 107 |
+
pip_wheels/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 108 |
+
pip_wheels/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 109 |
+
pip_wheels/opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 110 |
+
pip_wheels/pandas-2.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 111 |
+
pip_wheels/parso-0.8.5-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 112 |
+
pip_wheels/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 113 |
+
pip_wheels/plotly-6.5.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 114 |
+
pip_wheels/prompt_toolkit-3.0.52-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 115 |
+
pip_wheels/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 116 |
+
pip_wheels/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 117 |
+
pip_wheels/pydantic-2.12.5-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 118 |
+
pip_wheels/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 119 |
+
pip_wheels/pygments-2.19.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 120 |
+
pip_wheels/python_dateutil-2.9.0.post0-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 121 |
+
pip_wheels/pytz-2025.2-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 122 |
+
pip_wheels/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 123 |
+
pip_wheels/pyzmq-26.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 124 |
+
pip_wheels/rdkit-2024.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 125 |
+
pip_wheels/regex-2026.1.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 126 |
+
pip_wheels/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 127 |
+
pip_wheels/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 128 |
+
pip_wheels/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 129 |
+
pip_wheels/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 130 |
+
pip_wheels/sentry_sdk-2.51.0-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 131 |
+
pip_wheels/setuptools-80.10.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 132 |
+
pip_wheels/sympy-1.13.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 133 |
+
pip_wheels/timm-1.0.24-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 134 |
+
pip_wheels/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 135 |
+
pip_wheels/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 136 |
+
pip_wheels/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 137 |
+
pip_wheels/tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 138 |
+
pip_wheels/transformers-4.49.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 139 |
+
pip_wheels/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 140 |
+
pip_wheels/tzdata-2025.3-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 141 |
+
pip_wheels/urllib3-2.6.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
| 142 |
+
pip_wheels/wandb-0.23.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
|
| 143 |
+
pip_wheels/widgetsnbextension-4.0.15-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
|
FastGen/.claude/settings.local.json
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"permissions": {
|
| 3 |
+
"allow": [
|
| 4 |
+
"Bash(conda activate:*)",
|
| 5 |
+
"Bash(pip install:*)",
|
| 6 |
+
"Bash(source:*)",
|
| 7 |
+
"Bash(conda info:*)",
|
| 8 |
+
"Bash(pip cache:*)",
|
| 9 |
+
"Bash(python:*)",
|
| 10 |
+
"Bash(curl:*)",
|
| 11 |
+
"Bash(echo:*)",
|
| 12 |
+
"Bash(huggingface-cli download:*)",
|
| 13 |
+
"Bash(chmod:*)",
|
| 14 |
+
"Bash(bash:*)",
|
| 15 |
+
"Bash(ls:*)",
|
| 16 |
+
"Bash(huggingface-cli:*)"
|
| 17 |
+
]
|
| 18 |
+
}
|
| 19 |
+
}
|
FastGen/.gitignore
ADDED
|
@@ -0,0 +1,180 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# macOS
|
| 2 |
+
.DS_Store
|
| 3 |
+
|
| 4 |
+
# Byte-compiled / optimized / DLL files
|
| 5 |
+
__pycache__/
|
| 6 |
+
*.py[cod]
|
| 7 |
+
*$py.class
|
| 8 |
+
|
| 9 |
+
# C extensions
|
| 10 |
+
*.so
|
| 11 |
+
|
| 12 |
+
# Distribution / packaging
|
| 13 |
+
.Python
|
| 14 |
+
build/
|
| 15 |
+
develop-eggs/
|
| 16 |
+
dist/
|
| 17 |
+
downloads/
|
| 18 |
+
eggs/
|
| 19 |
+
.eggs/
|
| 20 |
+
lib/
|
| 21 |
+
lib64/
|
| 22 |
+
parts/
|
| 23 |
+
sdist/
|
| 24 |
+
var/
|
| 25 |
+
wheels/
|
| 26 |
+
share/python-wheels/
|
| 27 |
+
*.egg-info/
|
| 28 |
+
.installed.cfg
|
| 29 |
+
*.egg
|
| 30 |
+
MANIFEST
|
| 31 |
+
|
| 32 |
+
# PyInstaller
|
| 33 |
+
# Usually these files are written by a python script from a template
|
| 34 |
+
# before PyInstaller builds the exe, so as to inject date/other infos into it.
|
| 35 |
+
*.manifest
|
| 36 |
+
*.spec
|
| 37 |
+
|
| 38 |
+
# Installer logs
|
| 39 |
+
pip-log.txt
|
| 40 |
+
pip-delete-this-directory.txt
|
| 41 |
+
|
| 42 |
+
# Unit test / coverage reports
|
| 43 |
+
htmlcov/
|
| 44 |
+
.tox/
|
| 45 |
+
.nox/
|
| 46 |
+
.coverage
|
| 47 |
+
.coverage.*
|
| 48 |
+
.cache
|
| 49 |
+
nosetests.xml
|
| 50 |
+
coverage.xml
|
| 51 |
+
*.cover
|
| 52 |
+
*.py,cover
|
| 53 |
+
.hypothesis/
|
| 54 |
+
.pytest_cache/
|
| 55 |
+
cover/
|
| 56 |
+
|
| 57 |
+
# Translations
|
| 58 |
+
*.mo
|
| 59 |
+
*.pot
|
| 60 |
+
|
| 61 |
+
# Django stuff:
|
| 62 |
+
*.log
|
| 63 |
+
local_settings.py
|
| 64 |
+
db.sqlite3
|
| 65 |
+
db.sqlite3-journal
|
| 66 |
+
|
| 67 |
+
# Flask stuff:
|
| 68 |
+
instance/
|
| 69 |
+
.webassets-cache
|
| 70 |
+
|
| 71 |
+
# Scrapy stuff:
|
| 72 |
+
.scrapy
|
| 73 |
+
|
| 74 |
+
# Sphinx documentation
|
| 75 |
+
docs/_build/
|
| 76 |
+
|
| 77 |
+
# PyBuilder
|
| 78 |
+
.pybuilder/
|
| 79 |
+
target/
|
| 80 |
+
|
| 81 |
+
# Jupyter Notebook
|
| 82 |
+
.ipynb_checkpoints
|
| 83 |
+
|
| 84 |
+
# IPython
|
| 85 |
+
profile_default/
|
| 86 |
+
ipython_config.py
|
| 87 |
+
|
| 88 |
+
# pyenv
|
| 89 |
+
# For a library or package, you might want to ignore these files since the code is
|
| 90 |
+
# intended to run in multiple environments; otherwise, check them in:
|
| 91 |
+
# .python-version
|
| 92 |
+
|
| 93 |
+
# pipenv
|
| 94 |
+
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
|
| 95 |
+
# However, in case of collaboration, if having platform-specific dependencies or dependencies
|
| 96 |
+
# having no cross-platform support, pipenv may install dependencies that don't work, or not
|
| 97 |
+
# install all needed dependencies.
|
| 98 |
+
#Pipfile.lock
|
| 99 |
+
|
| 100 |
+
# poetry
|
| 101 |
+
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
|
| 102 |
+
# This is especially recommended for binary packages to ensure reproducibility, and is more
|
| 103 |
+
# commonly ignored for libraries.
|
| 104 |
+
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
|
| 105 |
+
#poetry.lock
|
| 106 |
+
|
| 107 |
+
# pdm
|
| 108 |
+
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
|
| 109 |
+
#pdm.lock
|
| 110 |
+
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
|
| 111 |
+
# in version control.
|
| 112 |
+
# https://pdm.fming.dev/#use-with-ide
|
| 113 |
+
.pdm.toml
|
| 114 |
+
|
| 115 |
+
# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
|
| 116 |
+
__pypackages__/
|
| 117 |
+
|
| 118 |
+
# Celery stuff
|
| 119 |
+
celerybeat-schedule
|
| 120 |
+
celerybeat.pid
|
| 121 |
+
|
| 122 |
+
# SageMath parsed files
|
| 123 |
+
*.sage.py
|
| 124 |
+
|
| 125 |
+
# Environments
|
| 126 |
+
.env
|
| 127 |
+
.venv
|
| 128 |
+
env/
|
| 129 |
+
venv/
|
| 130 |
+
ENV/
|
| 131 |
+
env.bak/
|
| 132 |
+
venv.bak/
|
| 133 |
+
|
| 134 |
+
# Spyder project settings
|
| 135 |
+
.spyderproject
|
| 136 |
+
.spyproject
|
| 137 |
+
|
| 138 |
+
# Rope project settings
|
| 139 |
+
.ropeproject
|
| 140 |
+
|
| 141 |
+
# mkdocs documentation
|
| 142 |
+
/site
|
| 143 |
+
|
| 144 |
+
# mypy
|
| 145 |
+
.mypy_cache/
|
| 146 |
+
.dmypy.json
|
| 147 |
+
dmypy.json
|
| 148 |
+
|
| 149 |
+
# Pyre type checker
|
| 150 |
+
.pyre/
|
| 151 |
+
|
| 152 |
+
# pytype static type analyzer
|
| 153 |
+
.pytype/
|
| 154 |
+
|
| 155 |
+
# Cython debug symbols
|
| 156 |
+
cython_debug/
|
| 157 |
+
|
| 158 |
+
# PyCharm
|
| 159 |
+
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
|
| 160 |
+
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
|
| 161 |
+
# and can be added to the global gitignore or merged into this file. For a more nuclear
|
| 162 |
+
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
|
| 163 |
+
.idea/
|
| 164 |
+
|
| 165 |
+
# large files
|
| 166 |
+
*.zip
|
| 167 |
+
|
| 168 |
+
# temporal folder
|
| 169 |
+
tmp/
|
| 170 |
+
|
| 171 |
+
# credentials
|
| 172 |
+
credentials
|
| 173 |
+
|
| 174 |
+
# logs
|
| 175 |
+
FASTGEN_OUTPUT/
|
| 176 |
+
wandb/
|
| 177 |
+
*.err
|
| 178 |
+
|
| 179 |
+
# vscode
|
| 180 |
+
*.vscode/
|
FastGen/CLAUDE.md
ADDED
|
@@ -0,0 +1,7 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# FastGen Project Notes
|
| 2 |
+
|
| 3 |
+
## Environment
|
| 4 |
+
|
| 5 |
+
- Conda environment: `sf`
|
| 6 |
+
- Activate: `conda activate sf`
|
| 7 |
+
- Install: `pip install -e .`
|
FastGen/CONTRIBUTING.md
ADDED
|
@@ -0,0 +1,93 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Contributing to FastGen
|
| 2 |
+
|
| 3 |
+
Thank you for your interest in contributing to FastGen!
|
| 4 |
+
|
| 5 |
+
## Issue Tracking
|
| 6 |
+
|
| 7 |
+
* All enhancement, bugfix, or change requests should begin with the creation of an issue.
|
| 8 |
+
* The issue should be reviewed and approved prior to code review.
|
| 9 |
+
|
| 10 |
+
## Development Setup
|
| 11 |
+
|
| 12 |
+
All development should be done using the provided Docker image to ensure a consistent environment.
|
| 13 |
+
|
| 14 |
+
```bash
|
| 15 |
+
# Build the Docker image
|
| 16 |
+
docker build -t fastgen:dev .
|
| 17 |
+
|
| 18 |
+
# Run interactively with GPU support (as current user, not root)
|
| 19 |
+
docker run --gpus all -it --rm \
|
| 20 |
+
--user $(id -u):$(id -g) \
|
| 21 |
+
-e HOME=/workspace \
|
| 22 |
+
-v /etc/passwd:/etc/passwd:ro \
|
| 23 |
+
-v /etc/group:/etc/group:ro \
|
| 24 |
+
-v $(pwd):/workspace \
|
| 25 |
+
fastgen:dev bash
|
| 26 |
+
|
| 27 |
+
# Inside the container, install linters
|
| 28 |
+
make install
|
| 29 |
+
```
|
| 30 |
+
|
| 31 |
+
## Code Style and Formatting
|
| 32 |
+
|
| 33 |
+
FastGen uses **Ruff** for linting and code formatting. Follow existing conventions when adding new code.
|
| 34 |
+
|
| 35 |
+
```bash
|
| 36 |
+
make format # Auto-format code
|
| 37 |
+
make lint # Check compliance
|
| 38 |
+
```
|
| 39 |
+
|
| 40 |
+
**Note:** The `fastgen/third_party/` directory is excluded from checks.
|
| 41 |
+
|
| 42 |
+
## Continuous Integration
|
| 43 |
+
|
| 44 |
+
The GitLab CI pipeline runs automatically on every push and merge request:
|
| 45 |
+
|
| 46 |
+
| Stage | Commands |
|
| 47 |
+
|-------|----------|
|
| 48 |
+
| Lint | `make lint`, `make mypy` |
|
| 49 |
+
| Test | `make pytest` |
|
| 50 |
+
| Install | `make install-fastgen` |
|
| 51 |
+
|
| 52 |
+
**Note:** Before submitting a pull request, ensure all checks pass locally.
|
| 53 |
+
## Pull Request Process
|
| 54 |
+
|
| 55 |
+
1. Fork the repository (for external contributors) or create a feature branch.
|
| 56 |
+
2. Make your changes following the code style guidelines.
|
| 57 |
+
3. Run all checks locally (see [CI section](#continuous-integration)).
|
| 58 |
+
4. Commit your changes with sign-off (PRs with unsigned commits will not be accepted):
|
| 59 |
+
```bash
|
| 60 |
+
git commit -s -m "Add feature X"
|
| 61 |
+
```
|
| 62 |
+
5. Push and create a Pull Request.
|
| 63 |
+
|
| 64 |
+
**Note:** Keep PRs focused on a single concern and reference related issues (e.g., "Fixes #123").
|
| 65 |
+
|
| 66 |
+
|
| 67 |
+
## Developer Certificate of Origin
|
| 68 |
+
|
| 69 |
+
```
|
| 70 |
+
Developer Certificate of Origin
|
| 71 |
+
Version 1.1
|
| 72 |
+
|
| 73 |
+
Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
|
| 74 |
+
1 Letterman Drive
|
| 75 |
+
Suite D4700
|
| 76 |
+
San Francisco, CA, 94129
|
| 77 |
+
|
| 78 |
+
Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
|
| 79 |
+
```
|
| 80 |
+
|
| 81 |
+
```
|
| 82 |
+
Developer's Certificate of Origin 1.1
|
| 83 |
+
|
| 84 |
+
By making a contribution to this project, I certify that:
|
| 85 |
+
|
| 86 |
+
(a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
|
| 87 |
+
|
| 88 |
+
(b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
|
| 89 |
+
|
| 90 |
+
(c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
|
| 91 |
+
|
| 92 |
+
(d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
|
| 93 |
+
```
|
FastGen/Dockerfile
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
FROM nvcr.io/nvidia/pytorch:25.06-py3
|
| 2 |
+
|
| 3 |
+
# Defines the architecture (amd64 or arm64) automatically during build
|
| 4 |
+
ARG TARGETARCH
|
| 5 |
+
|
| 6 |
+
ENV PYTHONDONTWRITEBYTECODE=1
|
| 7 |
+
ENV PYTHONUNBUFFERED=1
|
| 8 |
+
ENV DEBIAN_FRONTEND=noninteractive
|
| 9 |
+
|
| 10 |
+
# Set the timezone
|
| 11 |
+
ENV TZ=America/Los_Angeles
|
| 12 |
+
RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
|
| 13 |
+
|
| 14 |
+
# Install system dependencies
|
| 15 |
+
RUN apt-get update --fix-missing && \
|
| 16 |
+
apt-get install -y --no-install-recommends \
|
| 17 |
+
build-essential \
|
| 18 |
+
ninja-build \
|
| 19 |
+
cmake \
|
| 20 |
+
wget \
|
| 21 |
+
ca-certificates \
|
| 22 |
+
git \
|
| 23 |
+
curl \
|
| 24 |
+
unzip \
|
| 25 |
+
python3-tk \
|
| 26 |
+
&& rm -rf /var/lib/apt/lists/* \
|
| 27 |
+
&& cmake --version \
|
| 28 |
+
&& echo "✅ All system dependencies installed successfully"
|
| 29 |
+
|
| 30 |
+
# Install s5cmd with architecture logic
|
| 31 |
+
RUN S5CMD_VERSION="2.1.0-beta.1" && \
|
| 32 |
+
if [ "$TARGETARCH" = "arm64" ]; then \
|
| 33 |
+
S5CMD_ARCH="Linux-arm64"; \
|
| 34 |
+
else \
|
| 35 |
+
S5CMD_ARCH="Linux-64bit"; \
|
| 36 |
+
fi && \
|
| 37 |
+
wget "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz" && \
|
| 38 |
+
tar -xf "s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz" && \
|
| 39 |
+
install s5cmd /usr/local/bin/ && \
|
| 40 |
+
rm s5cmd "s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz"
|
| 41 |
+
|
| 42 |
+
# Set build optimization flags
|
| 43 |
+
ENV MAX_JOBS=4
|
| 44 |
+
ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0;10.0"
|
| 45 |
+
|
| 46 |
+
# Upgrade pip
|
| 47 |
+
RUN pip install --upgrade pip setuptools wheel
|
| 48 |
+
# numpy<2.0.0 is required for the preinstalled numba and thinc
|
| 49 |
+
RUN pip install wandb[media] \
|
| 50 |
+
diffusers==0.35.1 \
|
| 51 |
+
transformers==4.49.0 \
|
| 52 |
+
accelerate \
|
| 53 |
+
safetensors \
|
| 54 |
+
scipy \
|
| 55 |
+
einops \
|
| 56 |
+
jupyter \
|
| 57 |
+
hydra-core \
|
| 58 |
+
imageio \
|
| 59 |
+
tqdm \
|
| 60 |
+
webdataset \
|
| 61 |
+
av \
|
| 62 |
+
loguru \
|
| 63 |
+
boto3 \
|
| 64 |
+
timm \
|
| 65 |
+
ftfy \
|
| 66 |
+
opencv-python-headless \
|
| 67 |
+
sentencepiece \
|
| 68 |
+
"numpy<2.0.0"
|
| 69 |
+
RUN pip uninstall -y apex
|
| 70 |
+
RUN pip install --index-url=https://sc-hw-artf.nvidia.com/artifactory/api/pypi/hwinf-mlwfo-pypi/simple --upgrade one-logger-utils
|
| 71 |
+
|
| 72 |
+
WORKDIR /workspace
|
| 73 |
+
|
| 74 |
+
# Entry point
|
| 75 |
+
RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
|
| 76 |
+
ENTRYPOINT ["/entry.sh"]
|
| 77 |
+
|
FastGen/LICENSE
ADDED
|
@@ -0,0 +1,202 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
|
| 2 |
+
Apache License
|
| 3 |
+
Version 2.0, January 2004
|
| 4 |
+
http://www.apache.org/licenses/
|
| 5 |
+
|
| 6 |
+
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
| 7 |
+
|
| 8 |
+
1. Definitions.
|
| 9 |
+
|
| 10 |
+
"License" shall mean the terms and conditions for use, reproduction,
|
| 11 |
+
and distribution as defined by Sections 1 through 9 of this document.
|
| 12 |
+
|
| 13 |
+
"Licensor" shall mean the copyright owner or entity authorized by
|
| 14 |
+
the copyright owner that is granting the License.
|
| 15 |
+
|
| 16 |
+
"Legal Entity" shall mean the union of the acting entity and all
|
| 17 |
+
other entities that control, are controlled by, or are under common
|
| 18 |
+
control with that entity. For the purposes of this definition,
|
| 19 |
+
"control" means (i) the power, direct or indirect, to cause the
|
| 20 |
+
direction or management of such entity, whether by contract or
|
| 21 |
+
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
| 22 |
+
outstanding shares, or (iii) beneficial ownership of such entity.
|
| 23 |
+
|
| 24 |
+
"You" (or "Your") shall mean an individual or Legal Entity
|
| 25 |
+
exercising permissions granted by this License.
|
| 26 |
+
|
| 27 |
+
"Source" form shall mean the preferred form for making modifications,
|
| 28 |
+
including but not limited to software source code, documentation
|
| 29 |
+
source, and configuration files.
|
| 30 |
+
|
| 31 |
+
"Object" form shall mean any form resulting from mechanical
|
| 32 |
+
transformation or translation of a Source form, including but
|
| 33 |
+
not limited to compiled object code, generated documentation,
|
| 34 |
+
and conversions to other media types.
|
| 35 |
+
|
| 36 |
+
"Work" shall mean the work of authorship, whether in Source or
|
| 37 |
+
Object form, made available under the License, as indicated by a
|
| 38 |
+
copyright notice that is included in or attached to the work
|
| 39 |
+
(an example is provided in the Appendix below).
|
| 40 |
+
|
| 41 |
+
"Derivative Works" shall mean any work, whether in Source or Object
|
| 42 |
+
form, that is based on (or derived from) the Work and for which the
|
| 43 |
+
editorial revisions, annotations, elaborations, or other modifications
|
| 44 |
+
represent, as a whole, an original work of authorship. For the purposes
|
| 45 |
+
of this License, Derivative Works shall not include works that remain
|
| 46 |
+
separable from, or merely link (or bind by name) to the interfaces of,
|
| 47 |
+
the Work and Derivative Works thereof.
|
| 48 |
+
|
| 49 |
+
"Contribution" shall mean any work of authorship, including
|
| 50 |
+
the original version of the Work and any modifications or additions
|
| 51 |
+
to that Work or Derivative Works thereof, that is intentionally
|
| 52 |
+
submitted to Licensor for inclusion in the Work by the copyright owner
|
| 53 |
+
or by an individual or Legal Entity authorized to submit on behalf of
|
| 54 |
+
the copyright owner. For the purposes of this definition, "submitted"
|
| 55 |
+
means any form of electronic, verbal, or written communication sent
|
| 56 |
+
to the Licensor or its representatives, including but not limited to
|
| 57 |
+
communication on electronic mailing lists, source code control systems,
|
| 58 |
+
and issue tracking systems that are managed by, or on behalf of, the
|
| 59 |
+
Licensor for the purpose of discussing and improving the Work, but
|
| 60 |
+
excluding communication that is conspicuously marked or otherwise
|
| 61 |
+
designated in writing by the copyright owner as "Not a Contribution."
|
| 62 |
+
|
| 63 |
+
"Contributor" shall mean Licensor and any individual or Legal Entity
|
| 64 |
+
on behalf of whom a Contribution has been received by Licensor and
|
| 65 |
+
subsequently incorporated within the Work.
|
| 66 |
+
|
| 67 |
+
2. Grant of Copyright License. Subject to the terms and conditions of
|
| 68 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 69 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 70 |
+
copyright license to reproduce, prepare Derivative Works of,
|
| 71 |
+
publicly display, publicly perform, sublicense, and distribute the
|
| 72 |
+
Work and such Derivative Works in Source or Object form.
|
| 73 |
+
|
| 74 |
+
3. Grant of Patent License. Subject to the terms and conditions of
|
| 75 |
+
this License, each Contributor hereby grants to You a perpetual,
|
| 76 |
+
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
| 77 |
+
(except as stated in this section) patent license to make, have made,
|
| 78 |
+
use, offer to sell, sell, import, and otherwise transfer the Work,
|
| 79 |
+
where such license applies only to those patent claims licensable
|
| 80 |
+
by such Contributor that are necessarily infringed by their
|
| 81 |
+
Contribution(s) alone or by combination of their Contribution(s)
|
| 82 |
+
with the Work to which such Contribution(s) was submitted. If You
|
| 83 |
+
institute patent litigation against any entity (including a
|
| 84 |
+
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
| 85 |
+
or a Contribution incorporated within the Work constitutes direct
|
| 86 |
+
or contributory patent infringement, then any patent licenses
|
| 87 |
+
granted to You under this License for that Work shall terminate
|
| 88 |
+
as of the date such litigation is filed.
|
| 89 |
+
|
| 90 |
+
4. Redistribution. You may reproduce and distribute copies of the
|
| 91 |
+
Work or Derivative Works thereof in any medium, with or without
|
| 92 |
+
modifications, and in Source or Object form, provided that You
|
| 93 |
+
meet the following conditions:
|
| 94 |
+
|
| 95 |
+
(a) You must give any other recipients of the Work or
|
| 96 |
+
Derivative Works a copy of this License; and
|
| 97 |
+
|
| 98 |
+
(b) You must cause any modified files to carry prominent notices
|
| 99 |
+
stating that You changed the files; and
|
| 100 |
+
|
| 101 |
+
(c) You must retain, in the Source form of any Derivative Works
|
| 102 |
+
that You distribute, all copyright, patent, trademark, and
|
| 103 |
+
attribution notices from the Source form of the Work,
|
| 104 |
+
excluding those notices that do not pertain to any part of
|
| 105 |
+
the Derivative Works; and
|
| 106 |
+
|
| 107 |
+
(d) If the Work includes a "NOTICE" text file as part of its
|
| 108 |
+
distribution, then any Derivative Works that You distribute must
|
| 109 |
+
include a readable copy of the attribution notices contained
|
| 110 |
+
within such NOTICE file, excluding those notices that do not
|
| 111 |
+
pertain to any part of the Derivative Works, in at least one
|
| 112 |
+
of the following places: within a NOTICE text file distributed
|
| 113 |
+
as part of the Derivative Works; within the Source form or
|
| 114 |
+
documentation, if provided along with the Derivative Works; or,
|
| 115 |
+
within a display generated by the Derivative Works, if and
|
| 116 |
+
wherever such third-party notices normally appear. The contents
|
| 117 |
+
of the NOTICE file are for informational purposes only and
|
| 118 |
+
do not modify the License. You may add Your own attribution
|
| 119 |
+
notices within Derivative Works that You distribute, alongside
|
| 120 |
+
or as an addendum to the NOTICE text from the Work, provided
|
| 121 |
+
that such additional attribution notices cannot be construed
|
| 122 |
+
as modifying the License.
|
| 123 |
+
|
| 124 |
+
You may add Your own copyright statement to Your modifications and
|
| 125 |
+
may provide additional or different license terms and conditions
|
| 126 |
+
for use, reproduction, or distribution of Your modifications, or
|
| 127 |
+
for any such Derivative Works as a whole, provided Your use,
|
| 128 |
+
reproduction, and distribution of the Work otherwise complies with
|
| 129 |
+
the conditions stated in this License.
|
| 130 |
+
|
| 131 |
+
5. Submission of Contributions. Unless You explicitly state otherwise,
|
| 132 |
+
any Contribution intentionally submitted for inclusion in the Work
|
| 133 |
+
by You to the Licensor shall be under the terms and conditions of
|
| 134 |
+
this License, without any additional terms or conditions.
|
| 135 |
+
Notwithstanding the above, nothing herein shall supersede or modify
|
| 136 |
+
the terms of any separate license agreement you may have executed
|
| 137 |
+
with Licensor regarding such Contributions.
|
| 138 |
+
|
| 139 |
+
6. Trademarks. This License does not grant permission to use the trade
|
| 140 |
+
names, trademarks, service marks, or product names of the Licensor,
|
| 141 |
+
except as required for reasonable and customary use in describing the
|
| 142 |
+
origin of the Work and reproducing the content of the NOTICE file.
|
| 143 |
+
|
| 144 |
+
7. Disclaimer of Warranty. Unless required by applicable law or
|
| 145 |
+
agreed to in writing, Licensor provides the Work (and each
|
| 146 |
+
Contributor provides its Contributions) on an "AS IS" BASIS,
|
| 147 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
| 148 |
+
implied, including, without limitation, any warranties or conditions
|
| 149 |
+
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
| 150 |
+
PARTICULAR PURPOSE. You are solely responsible for determining the
|
| 151 |
+
appropriateness of using or redistributing the Work and assume any
|
| 152 |
+
risks associated with Your exercise of permissions under this License.
|
| 153 |
+
|
| 154 |
+
8. Limitation of Liability. In no event and under no legal theory,
|
| 155 |
+
whether in tort (including negligence), contract, or otherwise,
|
| 156 |
+
unless required by applicable law (such as deliberate and grossly
|
| 157 |
+
negligent acts) or agreed to in writing, shall any Contributor be
|
| 158 |
+
liable to You for damages, including any direct, indirect, special,
|
| 159 |
+
incidental, or consequential damages of any character arising as a
|
| 160 |
+
result of this License or out of the use or inability to use the
|
| 161 |
+
Work (including but not limited to damages for loss of goodwill,
|
| 162 |
+
work stoppage, computer failure or malfunction, or any and all
|
| 163 |
+
other commercial damages or losses), even if such Contributor
|
| 164 |
+
has been advised of the possibility of such damages.
|
| 165 |
+
|
| 166 |
+
9. Accepting Warranty or Additional Liability. While redistributing
|
| 167 |
+
the Work or Derivative Works thereof, You may choose to offer,
|
| 168 |
+
and charge a fee for, acceptance of support, warranty, indemnity,
|
| 169 |
+
or other liability obligations and/or rights consistent with this
|
| 170 |
+
License. However, in accepting such obligations, You may act only
|
| 171 |
+
on Your own behalf and on Your sole responsibility, not on behalf
|
| 172 |
+
of any other Contributor, and only if You agree to indemnify,
|
| 173 |
+
defend, and hold each Contributor harmless for any liability
|
| 174 |
+
incurred by, or claims asserted against, such Contributor by reason
|
| 175 |
+
of your accepting any such warranty or additional liability.
|
| 176 |
+
|
| 177 |
+
END OF TERMS AND CONDITIONS
|
| 178 |
+
|
| 179 |
+
APPENDIX: How to apply the Apache License to your work.
|
| 180 |
+
|
| 181 |
+
To apply the Apache License to your work, attach the following
|
| 182 |
+
boilerplate notice, with the fields enclosed by brackets "[]"
|
| 183 |
+
replaced with your own identifying information. (Don't include
|
| 184 |
+
the brackets!) The text should be enclosed in the appropriate
|
| 185 |
+
comment syntax for the file format. We also recommend that a
|
| 186 |
+
file or class name and description of purpose be included on the
|
| 187 |
+
same "printed page" as the copyright notice for easier
|
| 188 |
+
identification within third-party archives.
|
| 189 |
+
|
| 190 |
+
Copyright 2026 NVIDIA CORPORATION & AFFILIATES
|
| 191 |
+
|
| 192 |
+
Licensed under the Apache License, Version 2.0 (the "License");
|
| 193 |
+
you may not use this file except in compliance with the License.
|
| 194 |
+
You may obtain a copy of the License at
|
| 195 |
+
|
| 196 |
+
http://www.apache.org/licenses/LICENSE-2.0
|
| 197 |
+
|
| 198 |
+
Unless required by applicable law or agreed to in writing, software
|
| 199 |
+
distributed under the License is distributed on an "AS IS" BASIS,
|
| 200 |
+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 201 |
+
See the License for the specific language governing permissions and
|
| 202 |
+
limitations under the License.
|
FastGen/Makefile
ADDED
|
@@ -0,0 +1,29 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
sinclude .env
|
| 2 |
+
|
| 3 |
+
all: help
|
| 4 |
+
|
| 5 |
+
install-linters: ## install the linters
|
| 6 |
+
pip install black==23.10.0 ruff==0.6.9 mypy==1.9.0 types-psutil
|
| 7 |
+
|
| 8 |
+
install: install-linters
|
| 9 |
+
|
| 10 |
+
mypy:
|
| 11 |
+
python3 -m mypy --check-untyped-defs --follow-imports=silent --exclude third_party/ train.py
|
| 12 |
+
|
| 13 |
+
lint:
|
| 14 |
+
python3 -m ruff format --check --exclude fastgen/third_party/
|
| 15 |
+
python3 -m ruff check --exclude fastgen/third_party/
|
| 16 |
+
|
| 17 |
+
format:
|
| 18 |
+
python3 -m ruff format --exclude fastgen/third_party/
|
| 19 |
+
python3 -m ruff check --fix --exclude fastgen/third_party/
|
| 20 |
+
|
| 21 |
+
install-fastgen:
|
| 22 |
+
python3 -m pip install -e .
|
| 23 |
+
|
| 24 |
+
pytest:
|
| 25 |
+
ulimit -n 4096 && python3 -m pytest --ignore=FASTGEN_OUTPUT --ignore=runs --ignore=tmp --ignore third_party
|
| 26 |
+
|
| 27 |
+
.PHONY: help
|
| 28 |
+
help:
|
| 29 |
+
@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
|
FastGen/README.md
ADDED
|
@@ -0,0 +1,182 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
<h1 align="center">NVIDIA FastGen: Fast Generation from Diffusion Models</h1>
|
| 2 |
+
|
| 3 |
+
[](LICENSE)
|
| 4 |
+
[](https://pytorch.org/)
|
| 5 |
+
[](https://www.nvidia.com/)
|
| 6 |
+
|
| 7 |
+
<p align="center">
|
| 8 |
+
<b></b> <a href="https://weilinie.github.io/">Weili Nie</a> • <a href="https://jberner.info/">Julius Berner</a> • <a href="https://research.nvidia.com/labs/genair/author/chao-liu/">Chao Liu</a> • <a href="http://latentspace.cc/">Arash Vahdat</a>
|
| 9 |
+
</p>
|
| 10 |
+
|
| 11 |
+
<p align="center">
|
| 12 |
+
<img src="assets/teaser.png" alt="FastGen header" width=100%/>
|
| 13 |
+
</p>
|
| 14 |
+
|
| 15 |
+
FastGen is a PyTorch-based framework for building fast generative models using various distillation and acceleration techniques. It supports:
|
| 16 |
+
- large-scale training with ≥10B parameters.
|
| 17 |
+
- different tasks and modalities, including T2I, I2V, and V2V.
|
| 18 |
+
- various distillation methods, including consistency models, distribution matching distillation, self-forcing, and more.
|
| 19 |
+
|
| 20 |
+
## Repository Structure
|
| 21 |
+
|
| 22 |
+
```
|
| 23 |
+
fastgen/
|
| 24 |
+
├── fastgen/
|
| 25 |
+
│ ├── callbacks/ # Training callbacks (EMA, profiling, etc.)
|
| 26 |
+
│ ├── configs/ # Configuration system
|
| 27 |
+
│ │ ├── experiments/ # Experiment configs
|
| 28 |
+
│ │ └── methods/ # Method-specific configs
|
| 29 |
+
│ ├── datasets/ # Dataset loaders
|
| 30 |
+
│ ├── methods/ # Training methods (CM, DMD2, SFT, KD etc.)
|
| 31 |
+
│ ├── networks/ # Neural network architectures
|
| 32 |
+
│ ├── third_party/ # Third-party dependencies
|
| 33 |
+
│ ├── trainer.py # Main training loop
|
| 34 |
+
│ └── utils/ # Utilities (distributed, checkpointing)
|
| 35 |
+
├── scripts/ # Inference and evaluation scripts
|
| 36 |
+
├── tests/ # Unit tests
|
| 37 |
+
├── Makefile # Development commands (lint, format, test)
|
| 38 |
+
└── train.py # Main training entry point
|
| 39 |
+
```
|
| 40 |
+
|
| 41 |
+
## Setup
|
| 42 |
+
|
| 43 |
+
**Recommended:** Use the provided Docker container for a consistent environment. See [CONTRIBUTING.md](CONTRIBUTING.md) for Docker setup instructions. Otherwise, create a new [conda](https://www.anaconda.com/docs/getting-started/miniconda/install) environment with `conda create -y -n fastgen python=3.12.3 pip; conda activate fastgen`.
|
| 44 |
+
|
| 45 |
+
### Installation
|
| 46 |
+
|
| 47 |
+
```bash
|
| 48 |
+
git clone https://github.com/NVlabs/FastGen.git
|
| 49 |
+
cd FastGen
|
| 50 |
+
pip install -e .
|
| 51 |
+
```
|
| 52 |
+
|
| 53 |
+
### Credentials (Optional)
|
| 54 |
+
|
| 55 |
+
For W&B logging, [get your API key](https://wandb.ai/settings) and save it to `credentials/wandb_api.txt` or set the `WANDB_API_KEY` environment variable.
|
| 56 |
+
Without either of these, W&B will prompt for your API key interactively.
|
| 57 |
+
For more details, including S3 storage and other environment variables, see [fastgen/configs/README.md](fastgen/configs/README.md#environment-variables).
|
| 58 |
+
|
| 59 |
+
## Quick Start
|
| 60 |
+
|
| 61 |
+
Before running the following commands, download the CIFAR-10 dataset and pretrained EDM models:
|
| 62 |
+
|
| 63 |
+
```bash
|
| 64 |
+
python scripts/download_data.py --dataset cifar10
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
For other datasets and models, see [fastgen/networks/README.md](fastgen/networks/README.md) and [fastgen/datasets/README.md](fastgen/datasets/README.md).
|
| 68 |
+
|
| 69 |
+
### Basic Training
|
| 70 |
+
|
| 71 |
+
```bash
|
| 72 |
+
python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
If you run out-of-memory, try a smaller batch-size, e.g., `dataloader_train.batch_size=32`, which automatically uses gradient accumulation to match the global batch-size.
|
| 76 |
+
|
| 77 |
+
**Expected Output:** See the training log for a link to the run on [wandb.ai](https://wandb.ai). Training outputs go to `$FASTGEN_OUTPUT_ROOT/{project}/{group}/{name}/`. With default settings, outputs are organized as follows:
|
| 78 |
+
```
|
| 79 |
+
FASTGEN_OUTPUT/fastgen/cifar10/debug/
|
| 80 |
+
├── checkpoints/ # Model checkpoints in the format {iteration:07d}.pth
|
| 81 |
+
│ ├── 0001000.pth
|
| 82 |
+
│ └── ...
|
| 83 |
+
├── config.yaml # Resolved configuration for reproducibility
|
| 84 |
+
├── wandb_id.txt # W&B run ID for resuming
|
| 85 |
+
└── ...
|
| 86 |
+
```
|
| 87 |
+
|
| 88 |
+
### DDP/FSDP2 Training
|
| 89 |
+
|
| 90 |
+
For multi-GPU training, use DDP:
|
| 91 |
+
|
| 92 |
+
```bash
|
| 93 |
+
torchrun --nproc_per_node=8 train.py \
|
| 94 |
+
--config=fastgen/configs/experiments/EDM/config_dmd2_test.py \
|
| 95 |
+
- trainer.ddp=True log_config.name=test_ddp
|
| 96 |
+
```
|
| 97 |
+
|
| 98 |
+
For large models, use FSDP2 for model sharding by replacing `trainer.ddp=True` with `trainer.fsdp=True`.
|
| 99 |
+
|
| 100 |
+
|
| 101 |
+
### Inference
|
| 102 |
+
|
| 103 |
+
```bash
|
| 104 |
+
python scripts/inference/image_model_inference.py --config fastgen/configs/experiments/EDM/config_dmd2_test.py \
|
| 105 |
+
--classes=10 --prompt_file=scripts/inference/prompts/classes.txt --ckpt=FASTGEN_OUTPUT/fastgen/cifar10/debug/checkpoints/0002000.pth - log_config.name=test_inference
|
| 106 |
+
```
|
| 107 |
+
|
| 108 |
+
For other inferences modes and FID evaluations, see [scripts/README.md](scripts/README.md).
|
| 109 |
+
|
| 110 |
+
|
| 111 |
+
### Command-Line Overrides
|
| 112 |
+
|
| 113 |
+
Override any config parameter using Hydra-style syntax (note the `-` separator):
|
| 114 |
+
|
| 115 |
+
```bash
|
| 116 |
+
python train.py --config=path/to/config.py - key=value nested.key=value
|
| 117 |
+
```
|
| 118 |
+
|
| 119 |
+
## Documentation
|
| 120 |
+
|
| 121 |
+
Detailed documentation is available in each component's README:
|
| 122 |
+
|
| 123 |
+
| Component | Documentation | Description |
|
| 124 |
+
|-----------|---------------|-------------|
|
| 125 |
+
| **Methods** | [fastgen/methods/README.md](fastgen/methods/README.md) | Training methods (sCM, MeanFlow, DMD2, Self-Forcing, etc.) |
|
| 126 |
+
| **Networks** | [fastgen/networks/README.md](fastgen/networks/README.md) | Network architectures (EDM, SD, SDXL, Flux, WAN, CogVideoX, Cosmos) and pretrained models |
|
| 127 |
+
| **Configs** | [fastgen/configs/README.md](fastgen/configs/README.md) | Configuration system, environment variables, and creating custom configs |
|
| 128 |
+
| **Datasets** | [fastgen/datasets/README.md](fastgen/datasets/README.md) | Dataset preparation and WebDataset loaders |
|
| 129 |
+
| **Callbacks** | [fastgen/callbacks/README.md](fastgen/callbacks/README.md) | Training callbacks (EMA, logging, gradient clipping, etc.) |
|
| 130 |
+
| **Inference** | [scripts/README.md](scripts/README.md) | Inference modes (T2I, T2V, I2V, V2V, etc.) and FID evaluation |
|
| 131 |
+
| **Third Party** | [fastgen/third_party/README.md](fastgen/third_party/README.md) | Third-party dependencies (Depth Anything V2, etc.) |
|
| 132 |
+
|
| 133 |
+
## Supported Methods
|
| 134 |
+
|
| 135 |
+
| Category | Methods |
|
| 136 |
+
|----------|---------|
|
| 137 |
+
| **Consistency Models** | [CM](https://arxiv.org/abs/2303.01469), [sCM](https://arxiv.org/abs/2410.11081), [TCM](https://arxiv.org/abs/2410.14895), [MeanFlow](https://arxiv.org/abs/2505.13447) |
|
| 138 |
+
| **Distribution Matching** | [DMD2](https://arxiv.org/abs/2405.14867), [f-Distill](https://arxiv.org/abs/2502.15681), [LADD](https://arxiv.org/abs/2403.12015), [CausVid](https://arxiv.org/abs/2412.07772), [Self-Forcing](https://arxiv.org/abs/2506.08009) |
|
| 139 |
+
| **Fine-Tuning** | [SFT](https://arxiv.org/abs/2006.11239), [CausalSFT](https://arxiv.org/abs/2407.01392) |
|
| 140 |
+
| **Knowledge Distillation** | [KD](https://arxiv.org/abs/2101.02388), [CausalKD](https://arxiv.org/abs/2412.07772) |
|
| 141 |
+
|
| 142 |
+
See [fastgen/methods/README.md](fastgen/methods/README.md) for details.
|
| 143 |
+
|
| 144 |
+
## Supported Networks and Data
|
| 145 |
+
|
| 146 |
+
FastGen is designed to be **agnostic to the network and data** and you can add your own architectures and datasets (see [fastgen/networks/README.md](fastgen/networks/README.md) and [fastgen/datasets/README.md](fastgen/datasets/README.md)). For reference, we provide the following implementations:
|
| 147 |
+
|
| 148 |
+
| Data | Networks |
|
| 149 |
+
|------|----------|
|
| 150 |
+
| **Image** | EDM, EDM2, DiT, SD 1.5, SDXL, Flux |
|
| 151 |
+
| **Video** | WAN (T2V, I2V, VACE), CogVideoX, Cosmos Predict2 |
|
| 152 |
+
|
| 153 |
+
See [fastgen/networks/README.md](fastgen/networks/README.md) for details.
|
| 154 |
+
Not all combinations of methods and networks are currently supported. We provide typical use-cases in our predefined configs in [fastgen/configs/experiments](fastgen/configs/experiments).
|
| 155 |
+
|
| 156 |
+
**We plan to provide distilled student checkpoints for CIFAR-10 and ImageNet soon.**
|
| 157 |
+
|
| 158 |
+
## Contributing
|
| 159 |
+
|
| 160 |
+
We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details.
|
| 161 |
+
|
| 162 |
+
We thank everyone who has helped design, build, and test FastGen!
|
| 163 |
+
|
| 164 |
+
- **Core contributors:** Weili Nie, Julius Berner, Chao Liu
|
| 165 |
+
- **Other contributors:** James Lucas, David Pankratz, Sihyun Yu, Willis Ma, Yilun Xu, Shengqu Cai, Xinyin Ma, Yanke Song
|
| 166 |
+
- **Collaborators:** Sophia Zalewski, Wei Xiong, Christian Laforte, Sajad Norouzi, Kaiwen Zheng, Miloš Hašan, Saeed Hadadan, Gene Liu, David Dynerman, Grace Lam, Pooya Jannaty, Jan Kautz, and many more.
|
| 167 |
+
- **Project lead:** Arash Vahdat
|
| 168 |
+
|
| 169 |
+
## License
|
| 170 |
+
|
| 171 |
+
This project is licensed under the Apache License 2.0 - see [LICENSE](LICENSE) for details. Third-party licenses are documented in [licenses/README.md](licenses/README.md).
|
| 172 |
+
|
| 173 |
+
## Reference
|
| 174 |
+
|
| 175 |
+
```
|
| 176 |
+
@article{fastgen2026,
|
| 177 |
+
title={NVIDIA FastGen: Fast Generation from Diffusion Models},
|
| 178 |
+
author={Nie, Weili and Berner, Julius and Liu, Chao and Vahdat, Arash},
|
| 179 |
+
url={https://github.com/NVlabs/FastGen},
|
| 180 |
+
year={2026},
|
| 181 |
+
}
|
| 182 |
+
```
|
FastGen/SF.md
ADDED
|
@@ -0,0 +1,146 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Self-Forcing: Quick Reference Guide
|
| 2 |
+
|
| 3 |
+
Self-Forcing is a distribution matching distillation method for causal/autoregressive video generation. It trains a fast student model to match a teacher model's output distribution while maintaining gradient flow through an autoregressive rollout process.
|
| 4 |
+
|
| 5 |
+
**Reference**: [Huang et al., 2025](https://arxiv.org/abs/2506.08009)
|
| 6 |
+
|
| 7 |
+
---
|
| 8 |
+
|
| 9 |
+
## Environment Setup
|
| 10 |
+
|
| 11 |
+
```bash
|
| 12 |
+
conda activate sf
|
| 13 |
+
pip install -e .
|
| 14 |
+
```
|
| 15 |
+
|
| 16 |
+
---
|
| 17 |
+
|
| 18 |
+
## Quick Start
|
| 19 |
+
|
| 20 |
+
### 1. Download Checkpoint
|
| 21 |
+
|
| 22 |
+
```bash
|
| 23 |
+
huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt \
|
| 24 |
+
--local-dir FASTGEN_OUTPUT/MODEL/Self-Forcing
|
| 25 |
+
```
|
| 26 |
+
|
| 27 |
+
### 2. Run Training
|
| 28 |
+
|
| 29 |
+
#### Data-Free Training (Recommended for testing)
|
| 30 |
+
|
| 31 |
+
Uses random tensors instead of real video data. GAN discriminator is disabled.
|
| 32 |
+
|
| 33 |
+
**Single GPU:**
|
| 34 |
+
```bash
|
| 35 |
+
python train.py --config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
**Single GPU (specific device):**
|
| 39 |
+
```bash
|
| 40 |
+
CUDA_VISIBLE_DEVICES=0 python train.py --config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py
|
| 41 |
+
```
|
| 42 |
+
|
| 43 |
+
**Multi-GPU (8 GPUs):**
|
| 44 |
+
```bash
|
| 45 |
+
torchrun --nproc_per_node=8 train.py \
|
| 46 |
+
--config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py \
|
| 47 |
+
- trainer.ddp=True
|
| 48 |
+
```
|
| 49 |
+
|
| 50 |
+
#### Training with Real Video Data
|
| 51 |
+
|
| 52 |
+
For best quality, use real video data with GAN discriminator enabled.
|
| 53 |
+
|
| 54 |
+
Prepare WebDataset shards (see `fastgen/datasets/README.md`):
|
| 55 |
+
```
|
| 56 |
+
/path/to/videos/
|
| 57 |
+
├── 00000.tar (sample_001.mp4, sample_001.txt, ...)
|
| 58 |
+
├── 00001.tar
|
| 59 |
+
└── ...
|
| 60 |
+
```
|
| 61 |
+
|
| 62 |
+
**Single GPU:**
|
| 63 |
+
```bash
|
| 64 |
+
python train.py --config=fastgen/configs/experiments/WanT2V/config_sf.py \
|
| 65 |
+
- dataloader_train.datatags='["WDS:/path/to/your/videos"]'
|
| 66 |
+
```
|
| 67 |
+
|
| 68 |
+
**Multi-GPU (8 GPUs):**
|
| 69 |
+
```bash
|
| 70 |
+
torchrun --nproc_per_node=8 train.py \
|
| 71 |
+
--config=fastgen/configs/experiments/WanT2V/config_sf.py \
|
| 72 |
+
- trainer.ddp=True dataloader_train.datatags='["WDS:/path/to/your/videos"]'
|
| 73 |
+
```
|
| 74 |
+
|
| 75 |
+
---
|
| 76 |
+
|
| 77 |
+
## Available Configs
|
| 78 |
+
|
| 79 |
+
| Model | Config Path | Notes |
|
| 80 |
+
|-------|-------------|-------|
|
| 81 |
+
| WAN T2V (light) | `fastgen/configs/experiments/WanT2V/config_sf_datafree_light.py` | Data-free, reduced resolution (~8-12GB VRAM) |
|
| 82 |
+
| WAN T2V (data-free) | `fastgen/configs/experiments/WanT2V/config_sf_datafree.py` | Data-free, full 480p (~24GB VRAM) |
|
| 83 |
+
| WAN T2V (with data) | `fastgen/configs/experiments/WanT2V/config_sf.py` | Requires video data, GAN enabled |
|
| 84 |
+
| VACE-WAN V2V | `fastgen/configs/experiments/WanV2V/config_sf.py` | Requires video data |
|
| 85 |
+
|
| 86 |
+
---
|
| 87 |
+
|
| 88 |
+
## Key Configuration Parameters
|
| 89 |
+
|
| 90 |
+
### Self-Forcing Specific
|
| 91 |
+
|
| 92 |
+
| Parameter | Default | Description |
|
| 93 |
+
|-----------|---------|-------------|
|
| 94 |
+
| `enable_gradient_in_rollout` | `True` | Enable gradients at exit step during rollout |
|
| 95 |
+
| `start_gradient_frame` | `0` | Frame index to start gradient tracking |
|
| 96 |
+
| `same_step_across_blocks` | `True` | Use same exit step for all blocks |
|
| 97 |
+
| `last_step_only` | `False` | Exit only at the last denoising step |
|
| 98 |
+
| `context_noise` | `0.0` | Noise level added to cached context (0-1) |
|
| 99 |
+
|
| 100 |
+
### Training Settings
|
| 101 |
+
|
| 102 |
+
| Parameter | Default | Description |
|
| 103 |
+
|-----------|---------|-------------|
|
| 104 |
+
| `student_sample_steps` | `4` | Number of denoising steps |
|
| 105 |
+
| `student_update_freq` | `5` | Student update frequency |
|
| 106 |
+
| `gan_loss_weight_gen` | `0.001` | GAN loss weight for generator |
|
| 107 |
+
|
| 108 |
+
---
|
| 109 |
+
|
| 110 |
+
## Custom Training Example
|
| 111 |
+
|
| 112 |
+
```bash
|
| 113 |
+
python train.py --config=fastgen/configs/experiments/WanT2V/config_sf.py \
|
| 114 |
+
- model.gan_loss_weight_gen=0.005 \
|
| 115 |
+
model.student_sample_steps=6 \
|
| 116 |
+
model.context_noise=0.1 \
|
| 117 |
+
trainer.max_iter=10000 \
|
| 118 |
+
dataloader_train.batch_size=2
|
| 119 |
+
```
|
| 120 |
+
|
| 121 |
+
---
|
| 122 |
+
|
| 123 |
+
## How It Works
|
| 124 |
+
|
| 125 |
+
1. **Autoregressive Rollout**: Processes video chunk-by-chunk with KV-caching
|
| 126 |
+
2. **Gradient Tracking**: Enables gradients at stochastic exit steps during rollout
|
| 127 |
+
3. **Distribution Matching**: Combines VSD loss with GAN training
|
| 128 |
+
4. **Alternating Updates**:
|
| 129 |
+
- Student updates (VSD + GAN loss)
|
| 130 |
+
- Discriminator/fake score updates (adversarial training)
|
| 131 |
+
|
| 132 |
+
---
|
| 133 |
+
|
| 134 |
+
## Key Files
|
| 135 |
+
|
| 136 |
+
- **Method**: `fastgen/methods/distribution_matching/self_forcing.py`
|
| 137 |
+
- **Config Template**: `fastgen/configs/methods/config_self_forcing.py`
|
| 138 |
+
- **Tests**: `tests/test_sfmodel.py`
|
| 139 |
+
|
| 140 |
+
---
|
| 141 |
+
|
| 142 |
+
## Testing
|
| 143 |
+
|
| 144 |
+
```bash
|
| 145 |
+
pytest tests/test_sfmodel.py -v
|
| 146 |
+
```
|
FastGen/assets/teaser.png
ADDED
|
Git LFS Details
|
FastGen/fastgen/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
FastGen/fastgen/callbacks/README.md
ADDED
|
@@ -0,0 +1,59 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Callbacks
|
| 2 |
+
|
| 3 |
+
Training callbacks for FastGen. All callbacks inherit from `Callback` and hook into the training lifecycle.
|
| 4 |
+
|
| 5 |
+
## Available Callbacks
|
| 6 |
+
|
| 7 |
+
| Callback | Description |
|
| 8 |
+
|----------|-------------|
|
| 9 |
+
| [`EMACallback`](ema.py) | Exponential Moving Average updates (constant, power, or halflife schedules) |
|
| 10 |
+
| [`GradClipCallback`](grad_clip.py) | Gradient norm clipping with NaN/Inf handling |
|
| 11 |
+
| [`WandbCallback`](wandb.py) | Weights & Biases logging (metrics, images, videos) |
|
| 12 |
+
| [`GPUStatsCallback`](gpu_stats.py) | GPU memory and utilization logging |
|
| 13 |
+
| [`MemTrackerCallback`](gpu_mem_profiler.py) | GPU memory profiling with HTML visualizations |
|
| 14 |
+
| [`TrainProfilerCallback`](train_profiler.py) | Training speed and timing breakdown |
|
| 15 |
+
| [`ParamCountCallback`](param_count.py) | Parameter count logging |
|
| 16 |
+
| [`ForcedWeightNormCallback`](forced_weight_norm.py) | Forced weight normalization (EDM2) |
|
| 17 |
+
| [`CTScheduleCallback`](ct_schedule.py) | Consistency training curriculum schedule |
|
| 18 |
+
|
| 19 |
+
## Usage
|
| 20 |
+
|
| 21 |
+
Configure callbacks as a dictionary in the config:
|
| 22 |
+
|
| 23 |
+
```python
|
| 24 |
+
from fastgen.configs.callbacks import WANDB_CALLBACK, GradClip_CALLBACK, EMA_CALLBACK
|
| 25 |
+
from omegaconf import DictConfig
|
| 26 |
+
|
| 27 |
+
config.trainer.callbacks = DictConfig({
|
| 28 |
+
**WANDB_CALLBACK,
|
| 29 |
+
**GradClip_CALLBACK,
|
| 30 |
+
**EMA_CALLBACK,
|
| 31 |
+
})
|
| 32 |
+
```
|
| 33 |
+
|
| 34 |
+
Predefined configs are in `fastgen/configs/callbacks.py`.
|
| 35 |
+
|
| 36 |
+
## Lifecycle Hooks
|
| 37 |
+
|
| 38 |
+
Callbacks can override these hooks:
|
| 39 |
+
|
| 40 |
+
| Phase | Hooks |
|
| 41 |
+
|-------|-------|
|
| 42 |
+
| Setup | `on_app_begin`, `on_model_init_start/end`, `on_optimizer_init_start/end`, `on_load_checkpoint_start/end`, `on_dataloader_init_start/end` |
|
| 43 |
+
| Training | `on_train_begin/end`, `on_training_step_begin/end`, `on_training_accum_step_begin`, `on_backward_begin`, `on_optimizer_step_begin` |
|
| 44 |
+
| Validation | `on_validation_begin/end`, `on_validation_step_begin/end` |
|
| 45 |
+
| Checkpointing | `on_save_checkpoint_start/success/end`, `state_dict`, `load_state_dict` |
|
| 46 |
+
| Cleanup | `on_app_end` |
|
| 47 |
+
|
| 48 |
+
## Custom Callbacks
|
| 49 |
+
|
| 50 |
+
```python
|
| 51 |
+
from fastgen.callbacks.callback import Callback
|
| 52 |
+
|
| 53 |
+
class MyCallback(Callback):
|
| 54 |
+
def on_training_step_end(self, model, data_batch, output_batch, loss_dict, iteration=0):
|
| 55 |
+
if iteration % self.config.trainer.logging_iter == 0:
|
| 56 |
+
print(f"Step {iteration}: loss = {loss_dict['total_loss'].item():.4f}")
|
| 57 |
+
```
|
| 58 |
+
|
| 59 |
+
Callbacks have access to `self.config` and `self.trainer` after initialization.
|
FastGen/fastgen/callbacks/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
FastGen/fastgen/callbacks/callback.py
ADDED
|
@@ -0,0 +1,183 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import Callable, Any, TYPE_CHECKING
|
| 6 |
+
import torch
|
| 7 |
+
from torch.utils.data import DataLoader
|
| 8 |
+
|
| 9 |
+
from fastgen.utils import instantiate
|
| 10 |
+
import fastgen.utils.logging_utils as logger
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from fastgen.configs.config import BaseConfig
|
| 14 |
+
from fastgen.trainer import Trainer
|
| 15 |
+
from fastgen.methods import FastGenModel
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class CallbackDict:
|
| 19 |
+
def __init__(self, config: BaseConfig, trainer: Trainer):
|
| 20 |
+
self._callbacks = {}
|
| 21 |
+
callback_configs = config.trainer.callbacks
|
| 22 |
+
if callback_configs:
|
| 23 |
+
if isinstance(callback_configs, list):
|
| 24 |
+
logger.warning(msg="The 'config.trainer.callbacks' parameter should be a dict instead of a list. ")
|
| 25 |
+
callback_configs = {f"callback_{k}": v for k, v in enumerate(callback_configs)}
|
| 26 |
+
for callback_name, current_callback_cfg in callback_configs.items():
|
| 27 |
+
if "_target_" not in current_callback_cfg:
|
| 28 |
+
logger.critical(
|
| 29 |
+
f"Callback {callback_name} is missing the '_target_' field. \n Skip {current_callback_cfg}"
|
| 30 |
+
)
|
| 31 |
+
continue
|
| 32 |
+
logger.critical(f"Instantiating callback {callback_name}: {current_callback_cfg}")
|
| 33 |
+
_callback = instantiate(current_callback_cfg)
|
| 34 |
+
assert isinstance(_callback, Callback), f"{current_callback_cfg} is not a valid callback."
|
| 35 |
+
_callback.config = config
|
| 36 |
+
_callback.trainer = trainer
|
| 37 |
+
_callback.on_app_begin()
|
| 38 |
+
self._callbacks[callback_name] = _callback
|
| 39 |
+
|
| 40 |
+
def __getattr__(self, method_name: str) -> Callable:
|
| 41 |
+
def load_state_dict(state_dict: dict[str, Any]) -> None:
|
| 42 |
+
for name, callback in self._callbacks.items():
|
| 43 |
+
if name in state_dict:
|
| 44 |
+
callback.load_state_dict(state_dict[name])
|
| 45 |
+
else:
|
| 46 |
+
logger.warning(f"Callback {name} not found in checkpoint.")
|
| 47 |
+
|
| 48 |
+
def state_dict() -> dict[str, Any]:
|
| 49 |
+
return {name: self._callbacks[name].state_dict() for name in self._callbacks}
|
| 50 |
+
|
| 51 |
+
def callbacks_wrapper(*args, **kwargs):
|
| 52 |
+
for callback in self._callbacks.values():
|
| 53 |
+
assert hasattr(callback, method_name)
|
| 54 |
+
method = getattr(callback, method_name)
|
| 55 |
+
assert callable(method), f"{method_name} is not callable."
|
| 56 |
+
method(*args, **kwargs)
|
| 57 |
+
|
| 58 |
+
if method_name == "state_dict":
|
| 59 |
+
return state_dict
|
| 60 |
+
if method_name == "load_state_dict":
|
| 61 |
+
return load_state_dict
|
| 62 |
+
return callbacks_wrapper
|
| 63 |
+
|
| 64 |
+
|
| 65 |
+
class Callback:
|
| 66 |
+
config: "BaseConfig"
|
| 67 |
+
trainer: "Trainer"
|
| 68 |
+
|
| 69 |
+
def on_app_begin(self) -> None:
|
| 70 |
+
pass
|
| 71 |
+
|
| 72 |
+
def on_model_init_start(self, model: FastGenModel) -> None:
|
| 73 |
+
pass
|
| 74 |
+
|
| 75 |
+
def on_model_init_end(self, model: FastGenModel | torch.nn.parallel.DistributedDataParallel) -> None:
|
| 76 |
+
pass
|
| 77 |
+
|
| 78 |
+
def on_optimizer_init_start(self, model: FastGenModel) -> None:
|
| 79 |
+
pass
|
| 80 |
+
|
| 81 |
+
def on_optimizer_init_end(self, model: FastGenModel) -> None:
|
| 82 |
+
pass
|
| 83 |
+
|
| 84 |
+
def on_load_checkpoint_start(self, model: FastGenModel) -> None:
|
| 85 |
+
pass
|
| 86 |
+
|
| 87 |
+
def on_load_checkpoint_end(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 88 |
+
pass
|
| 89 |
+
|
| 90 |
+
def on_dataloader_init_start(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 91 |
+
pass
|
| 92 |
+
|
| 93 |
+
def on_dataloader_init_end(
|
| 94 |
+
self, model: FastGenModel, dataloader_train: DataLoader, dataloader_val: DataLoader, iteration: int = 0
|
| 95 |
+
) -> None:
|
| 96 |
+
pass
|
| 97 |
+
|
| 98 |
+
def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 99 |
+
pass
|
| 100 |
+
|
| 101 |
+
def on_training_step_begin(
|
| 102 |
+
self,
|
| 103 |
+
model: FastGenModel,
|
| 104 |
+
iteration: int = 0,
|
| 105 |
+
) -> None:
|
| 106 |
+
pass
|
| 107 |
+
|
| 108 |
+
def on_training_accum_step_begin(
|
| 109 |
+
self,
|
| 110 |
+
model: FastGenModel,
|
| 111 |
+
data_batch: dict[str, torch.Tensor],
|
| 112 |
+
iteration: int = 0,
|
| 113 |
+
accum_iter: int = 0,
|
| 114 |
+
) -> None:
|
| 115 |
+
pass
|
| 116 |
+
|
| 117 |
+
def on_backward_begin(
|
| 118 |
+
self,
|
| 119 |
+
model: FastGenModel,
|
| 120 |
+
data_batch: dict[str, torch.Tensor],
|
| 121 |
+
output_batch: dict[str, torch.Tensor | Callable],
|
| 122 |
+
loss_dict: dict[str, torch.Tensor],
|
| 123 |
+
iteration: int = 0,
|
| 124 |
+
accum_iter: int = 0,
|
| 125 |
+
) -> None:
|
| 126 |
+
pass
|
| 127 |
+
|
| 128 |
+
def on_training_step_end(
|
| 129 |
+
self,
|
| 130 |
+
model: FastGenModel,
|
| 131 |
+
data_batch: dict[str, torch.Tensor],
|
| 132 |
+
output_batch: dict[str, torch.Tensor | Callable],
|
| 133 |
+
loss_dict: dict[str, torch.Tensor],
|
| 134 |
+
iteration: int = 0,
|
| 135 |
+
) -> None:
|
| 136 |
+
pass
|
| 137 |
+
|
| 138 |
+
def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 139 |
+
pass
|
| 140 |
+
|
| 141 |
+
def on_train_end(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 142 |
+
pass
|
| 143 |
+
|
| 144 |
+
def on_validation_begin(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
|
| 145 |
+
pass
|
| 146 |
+
|
| 147 |
+
def on_validation_step_begin(
|
| 148 |
+
self, model: FastGenModel, data_batch: dict[str, torch.Tensor], step: int = 0, iteration: int = 0, idx: int = 0
|
| 149 |
+
) -> None:
|
| 150 |
+
pass
|
| 151 |
+
|
| 152 |
+
def on_validation_step_end(
|
| 153 |
+
self,
|
| 154 |
+
model: FastGenModel,
|
| 155 |
+
data_batch: dict[str, torch.Tensor],
|
| 156 |
+
output_batch: dict[str, torch.Tensor | Callable],
|
| 157 |
+
loss_dict: dict[str, torch.Tensor],
|
| 158 |
+
step: int = 0,
|
| 159 |
+
iteration: int = 0,
|
| 160 |
+
idx: int = 0,
|
| 161 |
+
) -> None:
|
| 162 |
+
pass
|
| 163 |
+
|
| 164 |
+
def on_validation_end(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
|
| 165 |
+
pass
|
| 166 |
+
|
| 167 |
+
def on_save_checkpoint_start(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 168 |
+
pass
|
| 169 |
+
|
| 170 |
+
def on_save_checkpoint_success(self, model: FastGenModel, iteration: int = 0, path: str = None) -> None:
|
| 171 |
+
pass
|
| 172 |
+
|
| 173 |
+
def on_save_checkpoint_end(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 174 |
+
pass
|
| 175 |
+
|
| 176 |
+
def on_app_end(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 177 |
+
pass
|
| 178 |
+
|
| 179 |
+
def state_dict(self) -> dict[str, Any]:
|
| 180 |
+
return {}
|
| 181 |
+
|
| 182 |
+
def load_state_dict(self, state_dict: dict[str, Any]) -> None:
|
| 183 |
+
pass
|
FastGen/fastgen/callbacks/ct_schedule.py
ADDED
|
@@ -0,0 +1,83 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import Callable, TYPE_CHECKING
|
| 6 |
+
import wandb
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
|
| 10 |
+
from fastgen.callbacks.callback import Callback
|
| 11 |
+
import fastgen.utils.logging_utils as logger
|
| 12 |
+
from fastgen.utils.basic_utils import get_batch_size_total
|
| 13 |
+
from fastgen.utils.distributed import is_rank0
|
| 14 |
+
|
| 15 |
+
if TYPE_CHECKING:
|
| 16 |
+
from fastgen.methods import FastGenModel
|
| 17 |
+
from fastgen.configs.config import BaseConfig
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class CTScheduleCallback(Callback):
|
| 21 |
+
config: "BaseConfig"
|
| 22 |
+
|
| 23 |
+
def __init__(
|
| 24 |
+
self,
|
| 25 |
+
q: float = 2.0,
|
| 26 |
+
ratio_limit: float = 0.999,
|
| 27 |
+
kimg_per_stage: int = 12500,
|
| 28 |
+
batch_size: int = 1,
|
| 29 |
+
):
|
| 30 |
+
self.q = q
|
| 31 |
+
self.ratio_limit = ratio_limit
|
| 32 |
+
self.kimg_per_stage = kimg_per_stage
|
| 33 |
+
self.batch_size = batch_size
|
| 34 |
+
|
| 35 |
+
self.stage = 0
|
| 36 |
+
self.ratio = 0.0
|
| 37 |
+
|
| 38 |
+
def _get_cur_stage(self, model, iteration):
|
| 39 |
+
# Start from the saved iteration of the first-stage model in TCM
|
| 40 |
+
if hasattr(model, "resume_iter"):
|
| 41 |
+
assert isinstance(model.resume_iter, int)
|
| 42 |
+
iteration = iteration + model.resume_iter
|
| 43 |
+
|
| 44 |
+
batch_size = self.batch_size
|
| 45 |
+
if hasattr(self, "config"):
|
| 46 |
+
# override the batch_size using self.config
|
| 47 |
+
batch_size = get_batch_size_total(self.config)
|
| 48 |
+
|
| 49 |
+
cur_nimg = iteration * batch_size
|
| 50 |
+
stage = cur_nimg // (self.kimg_per_stage * 1000)
|
| 51 |
+
return stage, cur_nimg
|
| 52 |
+
|
| 53 |
+
def _update_schedule(self, stage):
|
| 54 |
+
self.stage = stage
|
| 55 |
+
self.ratio = 1 - 1 / self.q ** (stage + 1)
|
| 56 |
+
if self.ratio > self.ratio_limit:
|
| 57 |
+
logger.info(f"Clipping ratio from {self.ratio} -> {self.ratio_limit}")
|
| 58 |
+
self.ratio = self.ratio_limit
|
| 59 |
+
|
| 60 |
+
def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
|
| 61 |
+
stage, _ = self._get_cur_stage(model, iteration)
|
| 62 |
+
self._update_schedule(stage)
|
| 63 |
+
setattr(model, "ratio", self.ratio)
|
| 64 |
+
|
| 65 |
+
def on_training_step_end(
|
| 66 |
+
self,
|
| 67 |
+
model: FastGenModel,
|
| 68 |
+
data_batch: dict[str, torch.Tensor],
|
| 69 |
+
output_batch: dict[str, torch.Tensor | Callable],
|
| 70 |
+
loss_dict: dict[str, torch.Tensor],
|
| 71 |
+
iteration: int = 0,
|
| 72 |
+
) -> None:
|
| 73 |
+
del data_batch, output_batch, loss_dict
|
| 74 |
+
new_stage, cur_nimg = self._get_cur_stage(model, iteration)
|
| 75 |
+
if new_stage > self.stage:
|
| 76 |
+
self._update_schedule(new_stage)
|
| 77 |
+
setattr(model, "ratio", self.ratio)
|
| 78 |
+
|
| 79 |
+
if hasattr(self, "config"):
|
| 80 |
+
# only wandb log when config exists
|
| 81 |
+
if iteration % self.config.trainer.logging_iter == 0 and is_rank0():
|
| 82 |
+
if wandb.run:
|
| 83 |
+
wandb.log({"ct_schedule/kimg": cur_nimg / 1e3, "ct_schedule/ratio": self.ratio}, step=iteration)
|
FastGen/fastgen/callbacks/ema.py
ADDED
|
@@ -0,0 +1,155 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import Callable, TYPE_CHECKING, Optional
|
| 7 |
+
|
| 8 |
+
import torch
|
| 9 |
+
import wandb
|
| 10 |
+
|
| 11 |
+
from fastgen.callbacks.callback import Callback
|
| 12 |
+
from fastgen.utils.basic_utils import get_batch_size_total
|
| 13 |
+
from fastgen.utils.distributed import synchronize, is_rank0
|
| 14 |
+
import fastgen.utils.logging_utils as logger
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from fastgen.methods import FastGenModel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class EMACallback(Callback):
|
| 21 |
+
def __init__(
|
| 22 |
+
self,
|
| 23 |
+
type: str = "constant",
|
| 24 |
+
# params for type=constant
|
| 25 |
+
beta: float = 0.9999,
|
| 26 |
+
# params for type=power
|
| 27 |
+
gamma: float = 16.97,
|
| 28 |
+
# params for type=halflife
|
| 29 |
+
ema_halflife_kimg: float = 500,
|
| 30 |
+
ema_rampup_ratio: Optional[float] = 0.05,
|
| 31 |
+
ema_name: str = "ema",
|
| 32 |
+
batch_size: int = 1, # overwritten by self.config if it exists
|
| 33 |
+
fsdp: bool = False, # overwritten by self.config if it exists
|
| 34 |
+
):
|
| 35 |
+
self.type = type
|
| 36 |
+
self.beta = beta
|
| 37 |
+
self.gamma = gamma
|
| 38 |
+
self.ema_halflife_kimg = ema_halflife_kimg
|
| 39 |
+
self.ema_rampup_ratio = ema_rampup_ratio
|
| 40 |
+
self.ema_name = ema_name
|
| 41 |
+
self.batch_size = batch_size
|
| 42 |
+
self._is_fsdp = fsdp
|
| 43 |
+
self._enabled = True
|
| 44 |
+
|
| 45 |
+
def on_app_begin(self) -> None:
|
| 46 |
+
if hasattr(self, "config"):
|
| 47 |
+
# override using config
|
| 48 |
+
self._is_fsdp = self.config.trainer.fsdp
|
| 49 |
+
self.batch_size = get_batch_size_total(self.config)
|
| 50 |
+
|
| 51 |
+
def on_model_init_end(
|
| 52 |
+
self, model: FastGenModel | torch.nn.parallel.DistributedDataParallel, iteration: int = 0
|
| 53 |
+
) -> None:
|
| 54 |
+
# Unwrap DDP if needed to access the original model's attributes
|
| 55 |
+
if hasattr(model, "module"):
|
| 56 |
+
model = model.module
|
| 57 |
+
|
| 58 |
+
# check ema initialization
|
| 59 |
+
ema = getattr(model, self.ema_name, None)
|
| 60 |
+
if ema is None:
|
| 61 |
+
self._enabled = False
|
| 62 |
+
logger.info(f"EMA {self.ema_name} is not enabled, skipping callback.")
|
| 63 |
+
return
|
| 64 |
+
|
| 65 |
+
assert ema.training is False, f"EMA {self.ema_name} should be in eval mode"
|
| 66 |
+
for name, p_net in ema.named_parameters():
|
| 67 |
+
assert not p_net.requires_grad, f"EMA parameter {name} should not require gradients"
|
| 68 |
+
|
| 69 |
+
def _total_iteration(self, model: FastGenModel, iteration: int) -> int:
|
| 70 |
+
if hasattr(model, "resume_iter"):
|
| 71 |
+
assert isinstance(model.resume_iter, int)
|
| 72 |
+
iteration = iteration + model.resume_iter
|
| 73 |
+
return iteration
|
| 74 |
+
|
| 75 |
+
def _power_function_beta(self, iteration):
|
| 76 |
+
beta = (1 - 1 / iteration) ** (self.gamma + 1)
|
| 77 |
+
return beta
|
| 78 |
+
|
| 79 |
+
def _get_cur_nimg(self, iteration):
|
| 80 |
+
cur_nimg = iteration * self.batch_size
|
| 81 |
+
return self.batch_size, cur_nimg
|
| 82 |
+
|
| 83 |
+
def _halflife_beta(self, iteration):
|
| 84 |
+
ema_halflife_nimg = self.ema_halflife_kimg * 1000
|
| 85 |
+
batch_size, cur_nimg = self._get_cur_nimg(iteration)
|
| 86 |
+
if self.ema_rampup_ratio is not None:
|
| 87 |
+
ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * self.ema_rampup_ratio)
|
| 88 |
+
ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
|
| 89 |
+
return ema_beta
|
| 90 |
+
|
| 91 |
+
def on_training_step_end(
|
| 92 |
+
self,
|
| 93 |
+
model: FastGenModel,
|
| 94 |
+
data_batch: dict[str, torch.Tensor],
|
| 95 |
+
output_batch: dict[str, torch.Tensor | Callable],
|
| 96 |
+
loss_dict: dict[str, torch.Tensor],
|
| 97 |
+
iteration: int = 0,
|
| 98 |
+
) -> None:
|
| 99 |
+
del data_batch, output_batch, loss_dict
|
| 100 |
+
|
| 101 |
+
# Check if EMA is enabled
|
| 102 |
+
if not self._enabled:
|
| 103 |
+
return
|
| 104 |
+
|
| 105 |
+
if self.type == "constant":
|
| 106 |
+
beta = self.beta
|
| 107 |
+
elif self.type == "power":
|
| 108 |
+
beta = self._power_function_beta(self._total_iteration(model, iteration))
|
| 109 |
+
elif self.type == "halflife":
|
| 110 |
+
beta = self._halflife_beta(self._total_iteration(model, iteration))
|
| 111 |
+
else:
|
| 112 |
+
raise ValueError(f"Invalid {self.ema_name} type: {self.type}")
|
| 113 |
+
|
| 114 |
+
with torch.no_grad():
|
| 115 |
+
ema = getattr(model, self.ema_name)
|
| 116 |
+
ema_state_dict = ema.state_dict()
|
| 117 |
+
|
| 118 |
+
for name, p_net in model.net.named_parameters():
|
| 119 |
+
if self._is_fsdp and hasattr(p_net, "full_tensor"):
|
| 120 |
+
# Gather the full tensor from all ranks if using FSDP with DTensor
|
| 121 |
+
# When CPU offloading is enabled, we need to move to CUDA first because
|
| 122 |
+
# full_tensor() performs an all_gather which requires a CUDA backend
|
| 123 |
+
if p_net.device.type == "cpu":
|
| 124 |
+
# Move local shard to CUDA, gather, then the result stays on CUDA
|
| 125 |
+
# which is fine since we'll copy to EMA (which handles device placement)
|
| 126 |
+
full_tensor = p_net.to("cuda").full_tensor()
|
| 127 |
+
else:
|
| 128 |
+
full_tensor = p_net.full_tensor()
|
| 129 |
+
else:
|
| 130 |
+
full_tensor = p_net
|
| 131 |
+
# Strip checkpoint wrapper prefix if present (EMA doesn't have checkpointing)
|
| 132 |
+
ema_name = name.replace("_checkpoint_wrapped_module.", "")
|
| 133 |
+
# Cast to EMA dtype and device (typically float32 on CPU) for lerp_ compatibility
|
| 134 |
+
if ema_name in ema_state_dict:
|
| 135 |
+
ema_param = ema_state_dict[ema_name]
|
| 136 |
+
ema_param.lerp_(full_tensor.to(device=ema_param.device, dtype=ema_param.dtype), 1.0 - beta)
|
| 137 |
+
elif iteration == 1:
|
| 138 |
+
# only warn on first iteration if parameter is not found
|
| 139 |
+
logger.warning(f"EMA parameter {ema_name} not found in EMA state dict, skipping update.")
|
| 140 |
+
|
| 141 |
+
# FSDP2 doesn't shard buffers, so we can just copy them
|
| 142 |
+
for name, p_net in model.net.named_buffers():
|
| 143 |
+
if name in ema_state_dict:
|
| 144 |
+
ema_param = ema_state_dict[name]
|
| 145 |
+
ema_param.copy_(p_net.to(device=ema_param.device, dtype=ema_param.dtype))
|
| 146 |
+
elif iteration == 1:
|
| 147 |
+
# only warn on first iteration if buffer is not found
|
| 148 |
+
logger.warning(f"EMA buffer {name} not found in EMA state dict, skipping update.")
|
| 149 |
+
|
| 150 |
+
if hasattr(self, "config"):
|
| 151 |
+
# only wandb log when config exists
|
| 152 |
+
if iteration % self.config.trainer.logging_iter == 0 and is_rank0():
|
| 153 |
+
if wandb.run:
|
| 154 |
+
wandb.log({f"ema/{self.ema_name}_beta": beta}, step=iteration)
|
| 155 |
+
synchronize()
|
FastGen/fastgen/callbacks/forced_weight_norm.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from typing import TYPE_CHECKING
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
from fastgen.callbacks.callback import Callback
|
| 10 |
+
import fastgen.utils.logging_utils as logger
|
| 11 |
+
|
| 12 |
+
if TYPE_CHECKING:
|
| 13 |
+
from fastgen.methods import FastGenModel
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
class ForcedWeightNormCallback(Callback):
|
| 17 |
+
def on_training_accum_step_begin(
|
| 18 |
+
self,
|
| 19 |
+
model: FastGenModel,
|
| 20 |
+
*args,
|
| 21 |
+
**kwargs,
|
| 22 |
+
) -> None:
|
| 23 |
+
if hasattr(model.net, "forced_weight_normalization"):
|
| 24 |
+
model.net.forced_weight_normalization()
|
| 25 |
+
else:
|
| 26 |
+
logger.warning(
|
| 27 |
+
"Enabled ForcedWeightNormCallback but model.net does not have the forced_weight_normalization method."
|
| 28 |
+
)
|
FastGen/fastgen/callbacks/gpu_mem_profiler.py
ADDED
|
@@ -0,0 +1,134 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
import os
|
| 8 |
+
from fastgen.utils import logging_utils as logger
|
| 9 |
+
from fastgen.callbacks.callback import Callback
|
| 10 |
+
import atexit
|
| 11 |
+
import pickle
|
| 12 |
+
from typing import Callable, Optional, TYPE_CHECKING
|
| 13 |
+
import base64
|
| 14 |
+
import json
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from fastgen.methods import FastGenModel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_dump(dump_path):
|
| 21 |
+
logger.critical(f"Creating {dump_path}")
|
| 22 |
+
if not dump_path.endswith("html"):
|
| 23 |
+
print(f"[{__file__}] create_dump produces an HTML file but was called with {dump_path=}")
|
| 24 |
+
torch.cuda.memory._dump_snapshot(dump_path + ".pickle")
|
| 25 |
+
with open(dump_path + ".pickle", "rb") as f:
|
| 26 |
+
data = pickle.load(f)
|
| 27 |
+
_memory_viz_template = r"""
|
| 28 |
+
<!DOCTYPE html>
|
| 29 |
+
<html>
|
| 30 |
+
<head>
|
| 31 |
+
</head>
|
| 32 |
+
<body>
|
| 33 |
+
<script type="module">
|
| 34 |
+
import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
|
| 35 |
+
const local_files = $SNAPSHOT
|
| 36 |
+
add_local_files(local_files, $VIZ_KIND)
|
| 37 |
+
</script>
|
| 38 |
+
</body>
|
| 39 |
+
"""
|
| 40 |
+
|
| 41 |
+
# find which GPU was active
|
| 42 |
+
idx_device = -1
|
| 43 |
+
for i in range(8):
|
| 44 |
+
if data["device_traces"][i]:
|
| 45 |
+
idx_device = i
|
| 46 |
+
break
|
| 47 |
+
|
| 48 |
+
traces = data["device_traces"][idx_device] # create an aliasing variable for convenience
|
| 49 |
+
traces = [
|
| 50 |
+
d for d in traces if d["action"] == "alloc" or d["action"] == "free_completed"
|
| 51 |
+
] # only the `alloc` and `free_completed` events matter for our visualization
|
| 52 |
+
|
| 53 |
+
for d in traces:
|
| 54 |
+
d["fastgen_frames"] = [
|
| 55 |
+
f for f in d["frames"] if "fastgen" in f["filename"]
|
| 56 |
+
] # get the callstack frames from fastgen code (e.g. ignore frames in pytorch/other libraries)
|
| 57 |
+
if not d["fastgen_frames"]:
|
| 58 |
+
d["fastgen_frames"] = d["frames"]
|
| 59 |
+
|
| 60 |
+
# run through the trace and find allocations that were allocated but never freed
|
| 61 |
+
set_alloced_addrs: dict = {}
|
| 62 |
+
for d in traces:
|
| 63 |
+
if d["action"] == "alloc":
|
| 64 |
+
set_alloced_addrs[d["addr"]] = d
|
| 65 |
+
elif d["action"] == "free_completed":
|
| 66 |
+
if d["addr"] in set_alloced_addrs:
|
| 67 |
+
del set_alloced_addrs[d["addr"]]
|
| 68 |
+
else:
|
| 69 |
+
raise NotImplementedError(f"{d['action']}")
|
| 70 |
+
|
| 71 |
+
never_freed_traces = list(set_alloced_addrs.values())
|
| 72 |
+
KB = 1 << 10
|
| 73 |
+
never_freed_traces = [t for t in never_freed_traces if t["size"] > KB] # get rid of allocations below 1 KB
|
| 74 |
+
|
| 75 |
+
# now proceed through the trace (guarenteed to be all `alloc` events as we removed all free events).
|
| 76 |
+
# for each pair of alloc events, merge them iff they share a common fastgen ancestor.
|
| 77 |
+
# Merging events is useful as it both speeds up the visualization rendering and also makes it more understandable.
|
| 78 |
+
i = 0
|
| 79 |
+
while i < len(never_freed_traces) - 1:
|
| 80 |
+
curr_frames = never_freed_traces[i]["fastgen_frames"]
|
| 81 |
+
next_frames = never_freed_traces[i + 1]["fastgen_frames"]
|
| 82 |
+
if (
|
| 83 |
+
curr_frames and next_frames and curr_frames[0] == next_frames[0]
|
| 84 |
+
): # TODO: probably should compare the full callstack
|
| 85 |
+
# same ancestor, delete next event and add its size to current event
|
| 86 |
+
never_freed_traces[i]["size"] += never_freed_traces[i + 1]["size"]
|
| 87 |
+
never_freed_traces.pop(i + 1)
|
| 88 |
+
else:
|
| 89 |
+
i += 1 # different ancestor, do not combine and move on
|
| 90 |
+
|
| 91 |
+
data["device_traces"][idx_device] = never_freed_traces # update the trace to only be the merged-alloc events
|
| 92 |
+
data["segments"] = [] # shrink the trace, unused in memory timeline
|
| 93 |
+
data["external_annotations"] = [] # shrink the trace, unused in memory timeline
|
| 94 |
+
buffer = pickle.dumps(data)
|
| 95 |
+
buffer += b"\x00" * (3 - len(buffer) % 3)
|
| 96 |
+
encoded_buffer = base64.b64encode(buffer).decode("utf-8")
|
| 97 |
+
json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}])
|
| 98 |
+
html_src = _memory_viz_template.replace("$VIZ_KIND", repr("Active Memory Timeline")).replace(
|
| 99 |
+
"$SNAPSHOT", json_format
|
| 100 |
+
)
|
| 101 |
+
with open(dump_path, "w") as f:
|
| 102 |
+
f.write(html_src)
|
| 103 |
+
|
| 104 |
+
|
| 105 |
+
class MemTrackerCallback(Callback):
    """Record the CUDA caching-allocator history and dump per-rank memory-timeline HTML reports."""

    def __init__(self, save_every_n_iters: Optional[int] = None, deactivate_after_n_iters: int = 100):
        self.save_every_n_iters = save_every_n_iters
        # Recording is switched off after this many iterations so the host-side
        # trace buffers cannot grow without bound.
        self.deactivate_after_n_iters = deactivate_after_n_iters

        def _dump_on_exit():
            # Final per-rank report written when the process terminates (including crashes).
            create_dump(
                f"{os.environ.get('FASTGEN_OUTPUT_ROOT', 'FASTGEN_OUTPUT')}/crash_rank{os.environ.get('RANK', '0')}.html"
            )

        self.atexit_fn = _dump_on_exit
        atexit.register(self.atexit_fn)

    def on_app_begin(self):
        """Start recording allocator events (Python-level stacks only)."""
        logger.info("[MemTrackerCallback] Tracking peak memory usage")
        torch.cuda.memory._record_memory_history(stacks="python")

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Optionally dump a per-step report and eventually stop recording."""
        if iteration > self.deactivate_after_n_iters:
            # Disabling recording frees PyTorch's internal tracking datastructures.
            torch.cuda.memory._record_memory_history(enabled=None)
        if self.save_every_n_iters is not None and (iteration % self.save_every_n_iters) == 0:
            create_dump(
                f"{os.environ.get('FASTGEN_OUTPUT_ROOT', 'FASTGEN_OUTPUT')}/step{iteration}_rank{os.environ.get('RANK', '0')}.html"
            )
|
FastGen/fastgen/callbacks/gpu_stats.py
ADDED
|
@@ -0,0 +1,92 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import os
|
| 7 |
+
from typing import TYPE_CHECKING, Callable, Any, Dict, List
|
| 8 |
+
|
| 9 |
+
import pandas as pd
|
| 10 |
+
import psutil
|
| 11 |
+
import torch
|
| 12 |
+
|
| 13 |
+
from fastgen.callbacks.callback import Callback
|
| 14 |
+
from fastgen.utils.distributed import world_size, is_rank0, synchronize
|
| 15 |
+
import fastgen.utils.logging_utils as logger
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from fastgen.methods import FastGenModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def log_prof_data(data_list: List[Dict[str, Any]]):
    """Summarize per-rank profiling dicts into Avg/Max/Min columns and log them as a table.

    Args:
        data_list: One dict per rank; all dicts are expected to share the same keys.
    """
    metric_names = list(data_list[0].keys())
    n_entries = len(data_list)

    # Aggregate each metric across ranks with the stdlib reducers.
    avg_values = {key: sum(d[key] for d in data_list) / n_entries for key in metric_names}
    max_values = {key: max(d[key] for d in data_list) for key in metric_names}
    min_values = {key: min(d[key] for d in data_list) for key in metric_names}

    summary_df = pd.DataFrame({"Avg": avg_values, "Max": max_values, "Min": min_values})

    logger.info(f"GPU stats:\n{summary_df.to_string()}")
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
class GPUStatsCallback(Callback):
    """Periodically gather CPU/GPU memory and utilization stats from every rank and log a summary."""

    def __init__(self, every_n: int = 100):
        self.every_n = every_n

    def on_train_begin(self, model: FastGenModel, iteration: int = 0):
        """Reset peak-memory counters; sync the logging period with the trainer config when present."""
        torch.cuda.reset_peak_memory_stats()
        if hasattr(self, "config"):
            # overwritten by logging_iter if self.config exists
            self.every_n = self.config.trainer.logging_iter
            logger.info(f"every_n to measure gpus stats: {self.every_n}")

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        del data_batch, output_batch, loss_dict
        if iteration % self.every_n != 0:
            return

        # Host-side memory: this process plus all of its (recursive) children.
        proc = psutil.Process(os.getpid())
        rss_bytes = proc.memory_info().rss
        for child in proc.children(recursive=True):
            rss_bytes += child.memory_info().rss

        gib = 1024**3
        prof_data = {
            "cpu_mem_gb": float(rss_bytes / gib),
            "peak_gpu_mem_gb": float(torch.cuda.max_memory_allocated() / gib),
            "peak_gpu_mem_reserved_gb": float(torch.cuda.max_memory_reserved() / gib),
            "util": float(torch.cuda.utilization()),
        }

        synchronize()
        data_list = [prof_data for _ in range(world_size())]
        if world_size() > 1:
            # all_gather_object is blocking by default
            torch.distributed.all_gather_object(data_list, prof_data)

        if is_rank0():
            log_prof_data(data_list)
        synchronize()
|
FastGen/fastgen/callbacks/grad_clip.py
ADDED
|
@@ -0,0 +1,222 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
from contextlib import contextmanager
|
| 7 |
+
from typing import TYPE_CHECKING, Optional
|
| 8 |
+
import wandb
|
| 9 |
+
|
| 10 |
+
import torch
|
| 11 |
+
from torch.distributed.tensor import DTensor
|
| 12 |
+
|
| 13 |
+
from fastgen.callbacks.callback import Callback
|
| 14 |
+
from fastgen.utils.distributed import is_rank0, world_size
|
| 15 |
+
import fastgen.utils.logging_utils as logger
|
| 16 |
+
|
| 17 |
+
if TYPE_CHECKING:
|
| 18 |
+
from fastgen.methods import FastGenModel
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
@contextmanager
def cast_gradients_dtype(model, dtype=torch.float32, enabled=True):
    """Temporarily cast all existing gradients of ``model`` to ``dtype``.

    On exit, gradients are restored to their parameter's dtype (not the
    gradient's original dtype — in practice the two coincide).
    A no-op passthrough when ``enabled`` is False.
    """
    if not enabled:
        yield
        return
    try:
        for p in model.parameters():
            grad = p.grad
            if grad is not None and grad.dtype != dtype:
                grad.data = grad.data.to(dtype)
        yield
    finally:
        # Restore: cast every gradient back to its parameter's dtype.
        for p in model.parameters():
            grad = p.grad
            if grad is not None and grad.dtype != p.dtype:
                grad.data = grad.data.to(p.dtype)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
def clip_grad_norm_fsdp(
    parameters,
    max_norm: float,
    norm_type: float = 2.0,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Clip gradients for FSDP2 models with CPU offloading.

    The standard torch.nn.utils.clip_grad_norm_ fails with FSDP2 CPU offloading because
    DTensor operations (like division) trigger all_reduce on CPU, which has no backend.

    This implementation:
    1. Extracts local tensors from DTensors
    2. Computes local norms on native device (CPU or GPU)
    3. All-reduces the scalar norm to get global norm
    4. Clips gradients in-place using the global norm

    Args:
        parameters: Iterable of parameters with gradients
        max_norm: Maximum norm value
        norm_type: Type of norm (default: L2)
        device: Device for all-reduce tensor. If None, inferred from gradients or defaults to cuda.

    Returns:
        Total gradient norm (global across all ranks) as a regular tensor
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(p for p in parameters if p.grad is not None)

    if len(parameters) == 0:
        # Nothing to clip: return a zero norm without touching any collective.
        return torch.tensor(0.0)

    # Compute per-parameter norms on their native device (CPU or GPU)
    # We compute norm^norm_type to allow proper aggregation across ranks
    # NOTE(review): this p-th-power aggregation assumes a finite norm_type;
    # an inf-norm would need a max-reduction instead and is not supported here.
    local_norm_sum = 0.0
    inferred_device = None
    for p in parameters:
        if isinstance(p.grad, DTensor):
            # Use the local shard directly so no DTensor collective is triggered.
            grad = p.grad._local_tensor
        else:
            grad = p.grad

        # Infer CUDA device from gradients (use first CUDA device found)
        if inferred_device is None and grad.device.type == "cuda":
            inferred_device = grad.device

        # Compute norm on the gradient's native device, accumulate as Python float.
        # .item() forces a host sync per parameter, but keeps the accumulation
        # backend-agnostic (works for CPU-offloaded grads).
        local_norm_sum += torch.norm(grad.detach().float(), norm_type).item() ** norm_type

    # Use provided device, or inferred device, or fall back to current CUDA device
    if device is None:
        device = inferred_device if inferred_device is not None else torch.device("cuda")
    local_norm_sum = torch.tensor(local_norm_sum, device=device)

    # All-reduce to get global norm across all ranks
    if world_size() > 1:
        torch.distributed.all_reduce(local_norm_sum, op=torch.distributed.ReduceOp.SUM)

    # Compute final global norm
    total_norm = local_norm_sum ** (1.0 / norm_type)

    # Compute clip coefficient (regular tensor division, no DTensor ops)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

    # Apply clipping to gradients
    for p in parameters:
        if isinstance(p.grad, DTensor):
            # For DTensor, scale the local tensor directly
            local_grad = p.grad._local_tensor
            local_grad.mul_(clip_coef_clamped.to(local_grad.device))
        else:
            p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))

    return total_norm
|
| 116 |
+
|
| 117 |
+
|
| 118 |
+
class GradClipCallback(Callback):
    """Sanitize (nan_to_num) and norm-clip the gradients of one sub-model before the optimizer step.

    Handles AMP grad-scaler unscaling, DTensor (FSDP2) gradients, and a dedicated
    clipping path for FSDP2 with CPU offloading.
    """

    def __init__(
        self,
        grad_norm: float | None = 1.0,
        model_key: str = "net",
        posinf: float | None = None,
        neginf: float | None = None,
        precision_grad_clip: Optional[torch.dtype] = None,
    ) -> None:
        # Max gradient norm; None disables clipping (sanitization still runs).
        self.grad_norm = grad_norm
        # Dotted attribute path selecting the sub-model to clip (e.g. "net").
        self.model_key = model_key
        # Replacement values for +inf / -inf gradients (None keeps torch defaults).
        self.posinf = posinf
        self.neginf = neginf
        # Optional dtype to cast gradients to during clipping for numerical stability.
        self.precision_grad_clip = precision_grad_clip

    def nan_to_num(self, module: torch.nn.Module) -> tuple[int, torch.dtype | None]:
        """Replace non-finite gradient entries in-place.

        Returns:
            (total count of non-finite gradient entries found,
             dtype of the last parameter gradient seen — None if no grads exist).
        """
        grad_dtype = None
        non_finite_grads_count = 0

        for name, param in module.named_parameters():
            if param.grad is not None:
                grad_dtype = param.grad.dtype

                # Extract local tensor for DTensor (avoids triggering distributed ops on CPU)
                if isinstance(param.grad, DTensor):
                    grad = param.grad._local_tensor
                else:
                    grad = param.grad

                non_finite_grads = grad.numel() - grad.isfinite().sum().item()
                if non_finite_grads:
                    non_finite_grads_count += non_finite_grads
                    logger.debug(
                        f"Gradient of {name} (dtype {grad_dtype}) is not finite: "
                        f"Setting {grad.isnan().sum().item()} NaNs to 0 and {grad.isinf().sum().item()} Infs "
                        f"to {self.posinf} or {self.neginf}."
                    )
                    # In-place sanitization (out=grad) so the parameter keeps the same grad storage.
                    torch.nan_to_num(grad, nan=0.0, posinf=self.posinf, neginf=self.neginf, out=grad)

        return non_finite_grads_count, grad_dtype

    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Unscale, sanitize, and clip gradients of the `model_key` sub-model; log stats to wandb."""
        # unscale the optimizer related to the `model_key`
        assert (
            self.model_key in model.optimizer_dict.keys()
        ), f"Keys in optimizer_dict: {list(model.optimizer_dict.keys())}."
        optimizer = model.optimizer_dict[self.model_key]
        # Only unscale if grad_scaler should be used (checks enabled + float32 grads)
        if model.should_use_grad_scaler(optimizer):
            model.grad_scaler.unscale_(optimizer)

        # Save model device before selecting subnet (subnet may not have .device)
        model_device = model.device

        # select subnet if specified (by default, we only perform gradient clips on model.net)
        # NOTE: `model` is rebound to the sub-module from here on.
        subnets = self.model_key.split(".")
        for subnet in subnets:
            model = getattr(model, subnet)

        # set nan to num for each parameter
        non_finite_grads_count, grad_dtype = self.nan_to_num(model)
        logger.debug(f"Gradient dtype of {self.model_key}: {grad_dtype}")
        if non_finite_grads_count > 0:
            logger.info(
                f"Number of parameters with non-finite gradients (of dtype {grad_dtype}): {non_finite_grads_count}"
            )
        log_dict = {f"optimizer/non_finite_grads_count (model_key {self.model_key})": non_finite_grads_count}

        if self.grad_norm is not None:
            # Cast all gradients to precision_grad_clip for numerical stability during clipping
            cast_grads = (
                self.precision_grad_clip is not None
                and grad_dtype is not None
                and grad_dtype != self.precision_grad_clip
            )

            # log value at first iteration
            if iteration == 1 and cast_grads:
                logger.info(f"Casting gradients from {grad_dtype} to {self.precision_grad_clip} before clipping.")

            # Check if CPU offloading is enabled by looking for DTensor grads on CPU
            # CPU offloading = DTensor local tensors are on CPU
            # No CPU offloading = DTensor local tensors are on GPU (or no DTensors)
            use_fsdp_cpu_offload_clip = False
            for p in model.parameters():
                if p.grad is not None and isinstance(p.grad, DTensor):
                    if p.grad._local_tensor.device.type == "cpu":
                        use_fsdp_cpu_offload_clip = True
                        break

            with cast_gradients_dtype(model, dtype=self.precision_grad_clip, enabled=cast_grads):
                if use_fsdp_cpu_offload_clip:
                    # Use custom clipping for FSDP with CPU offloading
                    # Standard clip_grad_norm_ fails because DTensor ops trigger all_reduce on CPU
                    total_norm = clip_grad_norm_fsdp(model.parameters(), self.grad_norm, device=model_device)
                else:
                    # Standard clipping for non-FSDP or FSDP without CPU offloading
                    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_norm, foreach=True)

            # .item() synchronizes with the device to fetch the scalar norm.
            log_dict[f"optimizer/grad_norm (model_key {self.model_key})"] = total_norm.item()

        if hasattr(self, "config"):
            # only wandb log when config exists
            if iteration % self.config.trainer.logging_iter == 0 and is_rank0() and wandb.run:
                wandb.log(log_dict, step=iteration)
|
FastGen/fastgen/callbacks/param_count.py
ADDED
|
@@ -0,0 +1,116 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
from typing import TYPE_CHECKING
|
| 6 |
+
|
| 7 |
+
from fastgen.callbacks.callback import Callback
|
| 8 |
+
from fastgen.utils.distributed import world_size
|
| 9 |
+
import fastgen.utils.logging_utils as logger
|
| 10 |
+
import torch
|
| 11 |
+
import wandb
|
| 12 |
+
|
| 13 |
+
try:
|
| 14 |
+
from torch.distributed.tensor import DTensor
|
| 15 |
+
except ImportError:
|
| 16 |
+
DTensor = None
|
| 17 |
+
|
| 18 |
+
if TYPE_CHECKING:
|
| 19 |
+
from fastgen.methods import FastGenModel
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def _get_local_numel(param: torch.Tensor) -> int:
|
| 23 |
+
"""Get the local (sharded) number of elements for a parameter.
|
| 24 |
+
|
| 25 |
+
For DTensor (FSDP2), returns the local shard size.
|
| 26 |
+
For regular tensors, returns the full size.
|
| 27 |
+
"""
|
| 28 |
+
if DTensor is not None and isinstance(param, DTensor):
|
| 29 |
+
return param._local_tensor.numel()
|
| 30 |
+
return param.numel()
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class ParamCountCallback(Callback):
    """At train start, log logical (full-model) and local (per-rank shard) parameter counts.

    When ranks disagree on a count, per-rank values are reported individually;
    otherwise the counts are collapsed to a single scalar.
    """

    def on_train_begin(self, model: FastGenModel, **kwargs) -> None:
        # get modules
        modules = {"model": model, **model.model_dict}

        # iterate over modules
        output = {}
        for name, module in modules.items():
            # Logical (full model) param counts
            trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in module.parameters())

            # Local (sharded) param counts - what's actually in memory on this rank
            local_trainable_params = sum(_get_local_numel(p) for p in module.parameters() if p.requires_grad)
            local_total_params = sum(_get_local_numel(p) for p in module.parameters())

            # check if parameter counts are different across ranks
            if world_size() > 1:
                trainable_params = self.gather_param_counts(trainable_params)
                total_params = self.gather_param_counts(total_params)
                local_trainable_params = self.gather_param_counts(local_trainable_params)
                local_total_params = self.gather_param_counts(local_total_params)
                # Collapse per-rank lists to scalars only when all ranks agree;
                # a remaining list signals a mismatch and is reported per rank below.
                if len(set(total_params)) == 1 and len(set(trainable_params)) == 1:
                    trainable_params = trainable_params[0]
                    total_params = total_params[0]
                if len(set(local_total_params)) == 1 and len(set(local_trainable_params)) == 1:
                    local_trainable_params = local_trainable_params[0]
                    local_total_params = local_total_params[0]

            # logging
            module_name = module.__class__.__name__
            output.update(
                {
                    f"{name}/trainable_params": trainable_params,
                    f"{name}/total_params": total_params,
                    f"{name}/local_trainable_params": local_trainable_params,
                    f"{name}/local_total_params": local_total_params,
                }
            )
            # Logical counts: either identical everywhere (scalar) or per-rank (list).
            if isinstance(trainable_params, list):
                logger.warning(f"Parameter counts differ across ranks for {module_name}.")
                for rank, (p_train, p) in enumerate(zip(trainable_params, total_params)):
                    logger.info(
                        f"{name} ({module_name}) has {p_train * 1.e-6:.2f} M trainable and {p * 1.e-6:.2f} M total params on rank {rank}."
                    )
            else:
                logger.info(
                    f"{name} ({module_name}) has {trainable_params * 1.e-6:.2f} M trainable and {total_params * 1.e-6:.2f} M total params (logical)."
                )

            # Report local/sharded counts
            if isinstance(local_trainable_params, list):
                for rank, (p_train, p) in enumerate(zip(local_trainable_params, local_total_params)):
                    logger.info(
                        f"{name} ({module_name}) has {p_train * 1.e-6:.2f} M trainable and {p * 1.e-6:.2f} M total params LOCAL on rank {rank}."
                    )
            else:
                # If logical counts already differ across ranks (list), treat as sharded.
                is_sharded = local_total_params < total_params if not isinstance(total_params, list) else True
                if is_sharded:
                    logger.info(
                        f"{name} ({module_name}) has {local_trainable_params * 1.e-6:.2f} M trainable and {local_total_params * 1.e-6:.2f} M total params LOCAL per rank (sharding ratio: {world_size()}x)."
                    )
                else:
                    logger.info(f"{name} ({module_name}) is NOT sharded (local == logical params).")

        if wandb.run:
            wandb.run.summary.update(output)

    def gather_param_counts(self, param_count):
        """
        Gather parameter counts across all ranks.

        Args:
            param_count: Parameter count to gather.

        Returns:
            List of parameter counts across all ranks.
        """
        param_count = torch.tensor(
            [param_count], dtype=torch.long, device="cuda" if torch.cuda.is_available() else "cpu"
        )
        param_count_list = [torch.zeros_like(param_count) for _ in range(world_size())]
        torch.distributed.all_gather(param_count_list, param_count)
        return [p.item() for p in param_count_list]
|
FastGen/fastgen/callbacks/train_profiler.py
ADDED
|
@@ -0,0 +1,138 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
|
| 6 |
+
import time
|
| 7 |
+
from typing import TYPE_CHECKING, Callable
|
| 8 |
+
|
| 9 |
+
import torch
|
| 10 |
+
import wandb
|
| 11 |
+
|
| 12 |
+
from fastgen.callbacks.callback import Callback
|
| 13 |
+
from fastgen.utils.distributed import is_rank0
|
| 14 |
+
import fastgen.utils.logging_utils as logger
|
| 15 |
+
|
| 16 |
+
if TYPE_CHECKING:
|
| 17 |
+
from fastgen.methods import FastGenModel
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
class TrainProfilerCallback(Callback):
    """Callback for profiling training speed and detailed timing breakdowns.

    Tracks:
    - iter_time: seconds per iteration (wall clock time)
    - data_load_time: time spent loading data
    - avg_forward_time: average forward pass time across accumulation steps
    - backward_time: time spent in backward pass
    - optim_step_time: time spent in optimizer step
    """

    def __init__(self, every_n: int = 100, detailed: bool = True):
        """Initialize the profiler callback.

        Args:
            every_n: Log metrics every N iterations
            detailed: If True, log detailed timing breakdown. If False, only log iter_time.
        """
        # For iter_time tracking; None until the first logging iteration passes.
        self.last_log_time = None

        # For detailed profiling — timestamps captured by the hooks below.
        self.detailed = detailed
        self.train_step_begin_time = None
        self.accum_begin_times = None
        self.backward_begin_times = None
        self.optimizer_step_begin = None
        self.step_end_time = None
        self.every_n = every_n

    def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        if hasattr(self, "config"):
            # overwritten by logging_iter if self.config exists
            self.every_n = self.config.trainer.logging_iter
            logger.info(f"every_n to profile trainer: {self.every_n}")

    def on_training_step_begin(
        self,
        model: FastGenModel,
        iteration: int = 0,
    ):
        # Reset per-step timestamp buffers at the top of each training step.
        if self.detailed:
            self.train_step_begin_time = time.perf_counter()
            self.accum_begin_times = []
            self.backward_begin_times = []

    def on_training_accum_step_begin(
        self, model: FastGenModel, data_batch: dict[str, torch.Tensor], iteration: int = 0, accum_iter: int = 0
    ):
        # One timestamp per gradient-accumulation micro-step (forward start).
        if self.detailed:
            self.accum_begin_times.append(time.perf_counter())

    def on_backward_begin(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
        accum_iter: int = 0,
    ):
        # Forward pass for this micro-step ends here; backward begins.
        if self.detailed:
            self.backward_begin_times.append(time.perf_counter())

    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0):
        if self.detailed:
            self.optimizer_step_begin = time.perf_counter()

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        del data_batch, output_batch, loss_dict

        if self.detailed:
            self.step_end_time = time.perf_counter()

        if hasattr(self, "config"):
            # only wandb log when config exists
            if iteration % self.every_n == 0 and is_rank0():
                metrics = {}

                # Calculate iter_time (wall clock time per iteration),
                # averaged over the window since the previous logging iteration.
                cur_time = time.time()
                if self.last_log_time is not None:
                    iter_time = (cur_time - self.last_log_time) / self.every_n
                    logger.info(f"{iteration} : avg iteration time {iter_time:.2f} seconds")
                    metrics["profiler/avg_iteration_time"] = iter_time
                self.last_log_time = cur_time

                # Calculate detailed timing breakdown.
                # data_load: step begin -> first forward begin.
                # forward: mean of (backward begin - forward begin) across micro-steps.
                # backward: last backward begin -> optimizer-step begin.
                # optim_step: optimizer-step begin -> step end.
                # NOTE(review): assumes on_optimizer_step_begin fired during this step;
                # otherwise optimizer_step_begin is None/stale — TODO confirm.
                if self.detailed and self.accum_begin_times and self.backward_begin_times:
                    data_load_time = self.accum_begin_times[0] - self.train_step_begin_time
                    forward_time = sum(
                        [b - a for (b, a) in zip(self.backward_begin_times, self.accum_begin_times)]
                    ) / len(self.accum_begin_times)
                    backward_time = self.optimizer_step_begin - self.backward_begin_times[-1]
                    optim_step_time = self.step_end_time - self.optimizer_step_begin

                    logger.info(f"{iteration} : data loading time {data_load_time:.2f}")
                    logger.info(f"{iteration} : avg forward pass time {forward_time:.2f}")
                    logger.info(f"{iteration} : backward pass time {backward_time:.2f}")
                    logger.info(f"{iteration} : optimizer step time {optim_step_time:.2f}")

                    metrics.update(
                        {
                            "profiler/data_loading_time": data_load_time,
                            "profiler/avg_forward_pass_time": forward_time,
                            "profiler/backward_pass_time": backward_time,
                            "profiler/optimizer_step_time": optim_step_time,
                        }
                    )

                if wandb.run and metrics:
                    wandb.log(metrics, step=iteration)
|
FastGen/fastgen/callbacks/wandb.py
ADDED
|
@@ -0,0 +1,404 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from __future__ import annotations
|
| 5 |
+
import os
|
| 6 |
+
from dataclasses import dataclass, field
|
| 7 |
+
import time
|
| 8 |
+
from typing import Optional, Dict, Callable, TYPE_CHECKING
|
| 9 |
+
import gc
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
import torch
|
| 13 |
+
import torchvision
|
| 14 |
+
from torchvision.transforms import functional as tv_F
|
| 15 |
+
|
| 16 |
+
import wandb
|
| 17 |
+
import wandb.util
|
| 18 |
+
|
| 19 |
+
from fastgen.callbacks.callback import Callback
|
| 20 |
+
from fastgen.configs.config_utils import serialize_config
|
| 21 |
+
from fastgen.utils import basic_utils
|
| 22 |
+
|
| 23 |
+
from fastgen.utils.distributed import rank0_only, synchronize, world_size
|
| 24 |
+
from fastgen.utils import logging_utils as logger
|
| 25 |
+
|
| 26 |
+
if TYPE_CHECKING:
|
| 27 |
+
from fastgen.configs.config import BaseConfig
|
| 28 |
+
from fastgen.methods import FastGenModel
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
def to_wandb(
    tensor: torch.Tensor,
    rgb_range: float = 255.0,
    normalized: bool = False,
    max_plot_img: int = 16,
    max_plot_vid: int = 2,
    fps: int = 16,
    channel_before_time: bool = True,
    caption: str | None = None,
    vid_format: str = "mp4",
) -> wandb.Image | wandb.Video:
    """
    Convert a tensor to a wandb.Image or wandb.Video.

    Args:
        tensor (torch.Tensor): Input tensor of shape [B,C,H,W], [B,T,C,H,W], or [B,T,C,H,W,D].
        rgb_range (float, optional): Output target RGB range (can almost definitely be kept as 255).
            Defaults to 255.0.
        normalized (bool, optional): Whether the tensor is normalized to [0,1]. Defaults to False which assumes [-1,1] range.
        max_plot_img (int, optional): Max number of images to plot. Defaults to 16.
        max_plot_vid (int, optional): Max number of videos to plot. Defaults to 2.
        fps (int, optional): Frames per second. Defaults to 16.
        channel_before_time (bool, optional): Whether the tensor is in the format [B,C,T,..]. Set False if the [B,T,C,..] format is used.
        caption (str, optional): Caption for the image or video. Defaults to None.
        vid_format (str, optional): Format of the video file. Defaults to "mp4".

    Returns:
        wandb.Image | wandb.Video: Format a tensor for logging to W&B.

    Raises:
        ValueError: If the tensor is not 4- or 5-dimensional.
    """

    # 5D tensors are treated as videos. wandb.Video expects time before
    # channels ([B,T,C,H,W]), so permute when the input is channel-first.
    if tensor.ndim == 5:
        max_plot = max_plot_vid
        if channel_before_time:
            tensor = tensor.permute(0, 2, 1, 3, 4)
    elif tensor.ndim == 4:
        max_plot = max_plot_img
    else:
        raise ValueError(f"Tensor must be 4 or 5 dimensional, but got {tensor.ndim} dimensions")

    # slice and adjust range: map [0,1] (normalized) or [-1,1] (default)
    # linearly onto [0, rgb_range], then quantize to uint8
    if normalized:
        factor = rgb_range
        offset = 0.0
    else:
        factor = rgb_range / 2.0
        offset = rgb_range / 2.0
    tensor = tensor[:max_plot].mul(factor).add(offset).clip_(0, rgb_range).to(torch.uint8)

    # convert to wandb.Image or wandb.Video
    assert tensor.shape[-3] == 3, "Make sure that the data is in ..., C, H, W format"
    if tensor.ndim == 5:
        return wandb.Video(tensor.cpu().numpy(), fps=fps, format=vid_format, caption=caption)
    else:
        # Tile images into a single grid (4 per row, white padding) for one log entry.
        image_grid = torchvision.utils.make_grid(tensor, nrow=4, pad_value=1)
        image_grid = tv_F.to_pil_image(image_grid)
        return wandb.Image(image_grid, caption=caption)
|
| 87 |
+
|
| 88 |
+
|
| 89 |
+
@rank0_only
def init_wandb(config: BaseConfig):
    """Initialize Weights & Biases logging (runs on rank 0 only).

    Loads the API key from a credential file when present, reuses (or
    creates and persists) a run id under ``log_config.save_path`` so that
    restarts resume the same W&B run, then calls ``wandb.init``. Code
    upload is opt-in via the ``WANDB_UPLOAD_CODE`` environment variable
    because it can stall distributed training for minutes.

    Args:
        config (BaseConfig): Full experiment config; ``config.log_config``
            supplies project/group/name, save path, credential path, and
            the wandb mode.
    """
    # wandb login: read the API key from the credential file if it exists.
    # (Use context managers so the file handles are closed deterministically.)
    wandb_credential = config.log_config.wandb_credential
    if os.path.isfile(wandb_credential):
        with open(wandb_credential, encoding="utf-8") as cred_file:
            os.environ["WANDB_API_KEY"] = cred_file.read().strip("\n")
        logger.info(f"Loading WANDB_API_KEY from {wandb_credential}")

    wandb_config = config.log_config

    # Resume with or generate a wandb id. The id is persisted to disk so a
    # restarted job resumes the same W&B run instead of creating a new one.
    logger.info(f"wandb_config.save_path: {wandb_config.save_path}")
    os.makedirs(wandb_config.save_path, exist_ok=True)
    wandb_id_path = f"{wandb_config.save_path}/wandb_id.txt"
    if os.path.isfile(wandb_id_path):
        with open(wandb_id_path, encoding="utf-8") as id_file:
            wandb_id = id_file.read().strip()
        logger.info(f"Resuming with an existing wandb id: {wandb_id}")
    else:
        wandb_id = wandb.util.generate_id()
        with open(wandb_id_path, "w", encoding="utf-8") as id_file:
            id_file.write(f"{wandb_id}\n")
        logger.info(f"Generating a wandb id: {wandb_id}")

    # Get config as plain dict
    config_resolved = serialize_config(config, return_type="dict")

    # Initialize the wandb library.
    wandb.init(
        id=wandb_id,
        project=wandb_config.project,
        group=wandb_config.group,
        name=wandb_config.name,
        config=config_resolved,
        dir=wandb_config.save_path,
        resume="allow",
        mode=wandb_config.wandb_mode,
    )

    # Save a copy of code to a wandb Artifact (this can be slow)
    # Make code upload optional to avoid distributed training delays
    upload_code = basic_utils.str2bool(os.getenv("WANDB_UPLOAD_CODE", "false"))
    if upload_code:
        logger.info("Uploading code to wandb (this may take a few minutes)...")
        wandb.run.log_code(".")
        logger.info("Code upload to wandb completed")
    else:
        logger.info("Wandb code upload disabled (set WANDB_UPLOAD_CODE=true to enable)")
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class _LossDictRecord:
    """Accumulates per-loss sums and counts between logging events.

    A loss may be computed only on some iterations and/or only on a subset
    of ranks, so `get_stat` all-gathers the per-rank sums/counts and
    normalizes accordingly before reporting.
    """

    # loss name -> running sum of loss values (local to this rank)
    loss_dict: dict = field(default_factory=dict)
    # loss name -> number of times that loss was added (local to this rank)
    iter_count_dict: dict = field(default_factory=dict)

    def add(self, loss_dict: Optional[Dict[str, torch.Tensor]]) -> None:
        """Accumulate one step's losses; a ``None`` dict is a no-op.

        Assumes each value is a scalar tensor; ``.item()`` detaches it and
        forces a host sync.
        """
        if loss_dict is not None:
            for loss_name, loss_val in loss_dict.items():
                self.loss_dict[loss_name] = self.loss_dict.get(loss_name, 0.0) + loss_val.float().item()
                self.iter_count_dict[loss_name] = self.iter_count_dict.get(loss_name, 0) + 1

    def reset(self) -> None:
        """Clear all accumulated sums and counts."""
        self.loss_dict = {}
        self.iter_count_dict = {}

    def gather_dict(self, dictionary: Dict[str, float | int]) -> Dict[str, float | int]:
        """All-gather ``dictionary`` from every rank and sum values per key.

        This is a collective call (all_gather_object) and must be invoked on
        all ranks together. Returns the input unchanged when single-process.
        """
        n_ranks = world_size()
        if n_ranks > 1:
            dict_list = [None for _ in range(n_ranks)]
            torch.distributed.all_gather_object(dict_list, dictionary)
            # from list of dicts to dict of summed values
            dictionary = {}
            for d in dict_list:
                for key, value in d.items():
                    dictionary[key] = dictionary.get(key, 0.0) + value
        return dictionary

    def get_stat(self) -> Dict[str, float]:
        """Compute globally averaged losses, log them, and reset the record.

        Collective: must be called on all ranks together (three gathers, in
        the same order on every rank).
        """
        # number of ranks that logged this loss
        rank_dict = self.gather_dict({k: 1 for k in self.loss_dict.keys()})
        # number of times this loss was computed
        count_dict = self.gather_dict(self.iter_count_dict)
        # sum of all losses
        loss_dict = self.gather_dict(self.loss_dict)

        avg_loss_dict = {}
        for loss_name, loss_val in loss_dict.items():
            count = count_dict.get(loss_name, 0)
            ranks = rank_dict.get(loss_name, 1)
            # average number of iterations per participating rank
            iter_count = count / ranks
            # mean over all recorded values, down-weighted by the fraction of
            # ranks that actually logged this loss
            avg_loss = (loss_val / count) * (ranks / world_size()) if count > 0 else 0.0
            logger.info(f"avg_{loss_name}: {avg_loss:.4f}".ljust(30) + f"iter count: {iter_count}")
            avg_loss_dict[loss_name] = avg_loss
        self.reset()
        return avg_loss_dict
|
| 183 |
+
|
| 184 |
+
|
| 185 |
+
class WandbCallback(Callback):
    """W&B logging callback: scalar stats, learning rates, and media samples.

    The callback gets precision for data from model. Media logging decodes
    latents through the model's VAE, optionally offloading the teacher and
    fake-score networks to CPU around decoding to avoid OOM.
    """

    def __init__(
        self,
        *args,
        validation_logging_step: int = 1,
        sample_logging_iter: Optional[int] = None,
        vid_format: str = "mp4",
        **kwargs,
    ):
        """
        Args:
            validation_logging_step: Log validation media every N validation steps.
            sample_logging_iter: Log training media every N training iterations;
                None defers to ``trainer.logging_iter`` (resolved in on_app_begin).
            vid_format: Container format passed to wandb.Video (e.g. "mp4").
        """
        super().__init__(*args, **kwargs)

        self.validation_logging_step = validation_logging_step
        self.sample_logging_iter = sample_logging_iter
        self.val_sample_map = None
        self.vid_format = vid_format
        # separate records so train and validation averages never mix
        self.loss_dict_record = _LossDictRecord()
        self.val_loss_dict_record = _LossDictRecord()

    def on_app_begin(self) -> None:
        """Initialize W&B (rank 0) and resolve logging options from config."""
        assert hasattr(self, "config"), "Missing config in WandbCallback."
        init_wandb(self.config)
        self.offload_module_in_decoding = self.config.trainer.offload_module_in_decoding
        # disable offloading if using FSDP
        if self.config.trainer.fsdp:
            self.offload_module_in_decoding = False
        if self.sample_logging_iter is None:
            self.sample_logging_iter = self.config.trainer.logging_iter
        synchronize()

    @rank0_only
    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Log the current learning rate of each scheduler every logging_iter."""
        assert hasattr(self, "config"), "Missing config in WandbCallback."
        if iteration % self.config.trainer.logging_iter == 0:
            for name, scheduler in model.scheduler_dict.items():
                wandb.log({f"optimizer/lr_{name}": scheduler.get_last_lr()[0]}, step=iteration)

    def get_sample_map(
        self, model: FastGenModel, data_batch: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor | Callable]
    ) -> dict[str, wandb.Image | wandb.Video]:
        """Build a dict of wandb media objects from a batch.

        Collective method: all ranks must call it together (it contains
        synchronize() barriers and possibly VAE decoding on every rank),
        but only ranks with an active wandb run populate the returned map.
        """
        # Collect generated and real data and create copies to avoid modifying the original dicts
        sample_map = {}
        gen_rand = output_batch["gen_rand"]
        # gen_rand may be a thunk to defer expensive sampling until logging time
        if isinstance(gen_rand, Callable):
            synchronize()
            gen_rand = gen_rand()
            synchronize()

        # Avoid modifying the original dicts
        data_batch = data_batch.copy()
        output_batch = output_batch.copy()

        # Decide whether we want to visualize multistep teacher generation
        if self.config.trainer.visualize_teacher:
            assert "input_rand" in output_batch, "We need to know the noise to visualize teacher generation"
            # Sample a single item ([0:1]) from the same noise the student used,
            # so teacher and student generations are directly comparable.
            teacher_output = model.sample(
                model.teacher,
                output_batch["input_rand"][0:1],
                data_batch["condition"][0:1],  # e.g. text condition encoded by the text encoder
                data_batch["neg_condition"][0:1],  # e.g. negative text condition encoded by the text encoder
            )
            output_batch["gen_teacher"] = teacher_output

        # Decode to pixel if it's in latent space
        # NOTE(review): presence of `init_preprocessors` is used as the marker for
        # latent-space models, but the code then calls `init_vae` — presumably both
        # exist on such nets; confirm against the net implementations.
        if hasattr(model.net, "init_preprocessors"):
            torch.cuda.empty_cache()
            device_nets = model.device

            # Lazily create the VAE if the net doesn't hold one, and drop it again below.
            has_vae = hasattr(model.net, "vae")
            if not has_vae:
                model.net.init_vae()
                model.net.vae.to(device=device_nets, dtype=model.precision)

            if self.offload_module_in_decoding:
                # offload the unneeded models to CPU (enable it if hitting OOM here)
                logger.info(
                    f"GPU Memory BEFORE moving nets to CPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
                )
                if hasattr(model, "fake_score"):
                    model.fake_score = model.fake_score.to("cpu")
                if hasattr(model, "teacher"):
                    model.teacher = model.teacher.to("cpu")
                logger.info(
                    f"GPU Memory AFTER moving nets to CPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
                )
                synchronize()

            with basic_utils.inference_mode(precision_amp=model.precision_amp_enc, device_type=device_nets.type):
                if "real" in data_batch:
                    # only generate one sample for video
                    limit = 1 if len(data_batch["real"].shape) == 5 else len(data_batch["real"])
                    data_batch["real"] = model.net.vae.decode(data_batch["real"][:limit])
                if isinstance(gen_rand, dict):
                    for k in gen_rand:
                        limit = 1 if len(gen_rand[k].shape) == 5 else len(gen_rand[k])
                        gen_rand[k] = model.net.vae.decode(gen_rand[k][:limit])
                else:
                    limit = 1 if len(gen_rand.shape) == 5 else len(gen_rand)
                    gen_rand = model.net.vae.decode(gen_rand[:limit])

                # NOTE(review): `limit` below is reused from the gen_rand branch above
                # (last dict key when gen_rand is a dict) — confirm this is intended.
                if "gen_teacher" in output_batch:
                    output_batch["gen_teacher"] = model.net.vae.decode(output_batch["gen_teacher"][:limit])
                if logger.LOG_LEVEL == "DEBUG" and "gen_rand_train" in output_batch:
                    output_batch["gen_rand_train"] = model.net.vae.decode(output_batch["gen_rand_train"][:limit])

            if not has_vae:
                del model.net.vae

            if self.offload_module_in_decoding:
                # move back fake_score to gpu
                if hasattr(model, "fake_score"):
                    model.fake_score = model.fake_score.to(device_nets)
                if hasattr(model, "teacher"):
                    model.teacher = model.teacher.to(device_nets)
                logger.info(
                    f"GPU Memory AFTER moving nets back to GPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
                )
                synchronize()

        # Only the rank with an active wandb run builds the media objects.
        if wandb.run:
            # Use raw text prompts as the media caption when available.
            if (
                "condition_raw" in data_batch
                and isinstance(data_batch["condition_raw"], (list, tuple))
                and isinstance(data_batch["condition_raw"][0], str)
            ):
                caption = "\n".join(data_batch["condition_raw"][: len(gen_rand)])
            else:
                caption = None
            if isinstance(gen_rand, dict):
                for k in gen_rand:
                    sample_map[f"student/generation/{k}"] = to_wandb(
                        gen_rand[k], caption=caption, vid_format=self.vid_format
                    )
            else:
                sample_map["student/generation"] = to_wandb(gen_rand, caption=caption, vid_format=self.vid_format)
            if "real" in data_batch:
                sample_map["data/real"] = to_wandb(data_batch["real"], caption=caption, vid_format=self.vid_format)
            if "gen_teacher" in output_batch:
                sample_map["teacher/generation"] = to_wandb(
                    output_batch["gen_teacher"], caption=caption, vid_format=self.vid_format
                )
            if logger.LOG_LEVEL == "DEBUG" and "gen_rand_train" in output_batch:
                sample_map["student/generation_train"] = to_wandb(
                    output_batch["gen_rand_train"], caption=caption, vid_format=self.vid_format
                )

        return sample_map

    def log_sample_map(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        suffix: str = "",
        iteration: int = 0,
        group: str = "train",
    ) -> None:
        """Build and log media samples, then free GPU memory.

        Collective: every rank must call it (get_sample_map synchronizes),
        though only the wandb-owning rank actually logs.
        """
        sample_map = self.get_sample_map(model, data_batch, output_batch)
        sample_map = {f"{group}_media/{k}{suffix}": v for k, v in sample_map.items()}
        if wandb.run:
            wandb.log(sample_map, step=iteration)
        synchronize()
        # Media decoding can allocate a lot; reclaim it before training resumes.
        gc.collect()
        torch.cuda.empty_cache()

    def log_stats(self, loss_dict_record: _LossDictRecord, iteration: int = 0, group: str = "train") -> None:
        """Gather loss averages across ranks and log them under ``group/``.

        Collective: get_stat performs all-gathers, so all ranks must call this.
        """
        logger.info(f"logging {group} stats at iteration {iteration}" + "-" * 20)
        # Collect distributed statistics
        avg_loss_dict = loss_dict_record.get_stat()
        stats = {f"{group}/{name}": val for name, val in avg_loss_dict.items()}
        base_info = {"optimizer/iteration": iteration}

        # log stats and base info
        if wandb.run:
            wandb.log(stats, step=iteration)
            wandb.log(base_info, step=iteration)

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Accumulate training losses; periodically log stats and media samples."""
        self.loss_dict_record.add(loss_dict)
        time_start = time.perf_counter()
        logged = False
        # Always log at iteration 1 so a fresh run shows data immediately.
        if iteration % self.config.trainer.logging_iter == 0 or iteration == 1:
            self.log_stats(self.loss_dict_record, iteration=iteration, group="train")
            logged = True
        if iteration % self.sample_logging_iter == 0 or iteration == 1:
            self.log_sample_map(model, data_batch, output_batch, iteration=iteration, group="train")
            logged = True
        if logged:
            time_taken = time.perf_counter() - time_start
            logger.info(f"WandB logging complete after {time_taken:.2f} seconds")

    def on_validation_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        step: int = 0,
        iteration: int = 0,
        idx: int = 0,
    ) -> None:
        """Accumulate validation losses; log media every validation_logging_step.

        ``idx`` distinguishes multiple validation dataloaders (val0, val1, ...).
        """
        self.val_loss_dict_record.add(loss_dict)

        if step % self.validation_logging_step == 0:
            self.log_sample_map(
                model, data_batch, output_batch, suffix=f"_{step}", iteration=iteration, group=f"val{idx}"
            )

    def on_validation_end(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
        """Flush accumulated validation loss averages for dataloader ``idx``."""
        self.log_stats(self.val_loss_dict_record, iteration=iteration, group=f"val{idx}")
|
FastGen/fastgen/configs/README.md
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Configuration System
|
| 2 |
+
|
| 3 |
+
FastGen uses a hierarchical Python-based configuration system built on [Hydra](https://hydra.cc/), [OmegaConf](https://omegaconf.readthedocs.io/), and [attrs](https://www.attrs.org/).
|
| 4 |
+
|
| 5 |
+
## Directory Structure
|
| 6 |
+
|
| 7 |
+
```
|
| 8 |
+
fastgen/configs/
|
| 9 |
+
├── experiments/ # Experiment configs (cifar10, Wan, sdxl, etc.)
|
| 10 |
+
├── methods/ # Method-specific configs (DMD2, CM, KD, SFT, etc.)
|
| 11 |
+
├── callbacks.py # Callback configurations (EMA, WandB, GradClip, etc.)
|
| 12 |
+
├── config_utils.py # Utilities (import, override, serialize configs)
|
| 13 |
+
├── config.py # Base config classes (BaseConfig, BaseModelConfig, BaseTrainerConfig)
|
| 14 |
+
├── data.py # Dataset/dataloader configurations
|
| 15 |
+
├── discriminator.py # Discriminator configs (for GAN-based methods)
|
| 16 |
+
├── net.py # Network architecture configurations
|
| 17 |
+
├── opt.py # Optimizer and scheduler configurations
|
| 18 |
+
```
|
| 19 |
+
|
| 20 |
+
## Config Hierarchy
|
| 21 |
+
|
| 22 |
+
1. **Base Configs** (`config.py`): Core classes defining model, trainer, data, and logging settings
|
| 23 |
+
2. **Method Configs** (`methods/`): Extend base configs with method-specific parameters
|
| 24 |
+
3. **Experiment Configs** (`experiments/`): Concrete configs for specific dataset/model combinations
|
| 25 |
+
|
| 26 |
+
## LazyCall Pattern
|
| 27 |
+
|
| 28 |
+
Deferred instantiation using `LazyCall`:
|
| 29 |
+
|
| 30 |
+
```python
|
| 31 |
+
from fastgen.utils import LazyCall as L
|
| 32 |
+
from fastgen.methods import DMD2Model
|
| 33 |
+
|
| 34 |
+
model_class = L(DMD2Model)(config=None) # Config dict with _target_
|
| 35 |
+
model = instantiate(model_class) # Instantiate later
|
| 36 |
+
```
|
| 37 |
+
|
| 38 |
+
## Command-Line Arguments
|
| 39 |
+
|
| 40 |
+
```bash
|
| 41 |
+
python train.py --config=path/to/config.py [--log_level LEVEL] [--dryrun] - key=value
|
| 42 |
+
```
|
| 43 |
+
|
| 44 |
+
| Argument | Description |
|
| 45 |
+
|----------|-------------|
|
| 46 |
+
| `--config` | Path to the config file |
|
| 47 |
+
| `--log_level` | Log level: DEBUG, INFO (default), WARNING, ERROR |
|
| 48 |
+
| `--dryrun` | Print resolved config and exit without training |
|
| 49 |
+
| `-` | Separator before config overrides (required) |
|
| 50 |
+
|
| 51 |
+
Examples:
|
| 52 |
+
|
| 53 |
+
```bash
|
| 54 |
+
# Override training settings
|
| 55 |
+
python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py - \
|
| 56 |
+
trainer.max_iter=10000 \
|
| 57 |
+
model.gan_loss_weight_gen=0. \
|
| 58 |
+
log_config.name=my_experiment
|
| 59 |
+
|
| 60 |
+
# Debug config without training
|
| 61 |
+
python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py --dryrun
|
| 62 |
+
|
| 63 |
+
# Verbose logging
|
| 64 |
+
python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py --log_level DEBUG
|
| 65 |
+
```
|
| 66 |
+
|
| 67 |
+
## Key Config Classes
|
| 68 |
+
|
| 69 |
+
| Class | Purpose |
|
| 70 |
+
|-------|---------|
|
| 71 |
+
| `BaseConfig` | Top-level: model, trainer, dataloader, logging |
|
| 72 |
+
| `BaseModelConfig` | Network, optimizer, precision, EMA, guidance |
|
| 73 |
+
| `BaseTrainerConfig` | Checkpointing, callbacks, DDP/FSDP, iterations |
|
| 74 |
+
| `LogConfig` | Project, group, name, wandb settings |
|
| 75 |
+
|
| 76 |
+
## Environment Variables
|
| 77 |
+
|
| 78 |
+
### Core Variables
|
| 79 |
+
|
| 80 |
+
| Variable | Description | Default |
|
| 81 |
+
|----------|-------------|---------|
|
| 82 |
+
| `FASTGEN_OUTPUT_ROOT` | Root directory for checkpoints, logs, and outputs | `FASTGEN_OUTPUT` |
|
| 83 |
+
| `DATA_ROOT_DIR` | Root directory for datasets | `$FASTGEN_OUTPUT_ROOT/DATA` |
|
| 84 |
+
| `CKPT_ROOT_DIR` | Root directory for pretrained checkpoints | `$FASTGEN_OUTPUT_ROOT/MODEL` |
|
| 85 |
+
| `HF_HOME` | HuggingFace cache directory | `$FASTGEN_OUTPUT_ROOT/.cache` |
|
| 86 |
+
| `LOCAL_FILES_ONLY` | Use only local files, skip downloads from HuggingFace | `false` |
|
| 87 |
+
| `WANDB_API_KEY` | W&B API key for persistent logging ([get yours here](https://wandb.ai/settings)) | (none, W&B will prompt) |
|
| 88 |
+
| `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, `AWS_ENDPOINT_URL` | S3 credentials for data and checkpoint storage | (none) |
|
| 89 |
+
|
| 90 |
+
### Loading Credentials from Files
|
| 91 |
+
|
| 92 |
+
As an alternative to setting environment variables directly, credentials can be loaded automatically from files in the `./credentials/` directory:
|
| 93 |
+
|
| 94 |
+
| File | Environment Variables Set |
|
| 95 |
+
|------|---------------------------|
|
| 96 |
+
| `./credentials/wandb_api.txt` | `WANDB_API_KEY` (plain text file containing your API key) |
|
| 97 |
+
| `./credentials/s3.json` | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, `AWS_ENDPOINT_URL` (JSON format below) |
|
| 98 |
+
|
| 99 |
+
**Format for `s3.json`:**
|
| 100 |
+
|
| 101 |
+
```json
|
| 102 |
+
{
|
| 103 |
+
"aws_access_key_id": "<your_access_key>",
|
| 104 |
+
"aws_secret_access_key": "<your_secret_key>",
|
| 105 |
+
"region_name": "<region>",
|
| 106 |
+
"endpoint_url": "<s3_endpoint_url>"
|
| 107 |
+
}
|
| 108 |
+
```
|
FastGen/fastgen/configs/__init__.py
ADDED
|
@@ -0,0 +1,2 @@
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
FastGen/fastgen/configs/callbacks.py
ADDED
|
@@ -0,0 +1,63 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from fastgen.utils import LazyCall as L
|
| 5 |
+
|
| 6 |
+
from fastgen.callbacks.ct_schedule import CTScheduleCallback
|
| 7 |
+
from fastgen.callbacks.grad_clip import GradClipCallback
|
| 8 |
+
from fastgen.callbacks.param_count import ParamCountCallback
|
| 9 |
+
from fastgen.callbacks.wandb import WandbCallback
|
| 10 |
+
from fastgen.callbacks.ema import EMACallback
|
| 11 |
+
from fastgen.callbacks.train_profiler import TrainProfilerCallback
|
| 12 |
+
from fastgen.callbacks.gpu_stats import GPUStatsCallback
|
| 13 |
+
from fastgen.callbacks.forced_weight_norm import ForcedWeightNormCallback
|
| 14 |
+
from fastgen.callbacks.gpu_mem_profiler import MemTrackerCallback
|
| 15 |
+
|
| 16 |
+
|
| 17 |
+
# Default callback configurations. Each constant is a dict mapping a callback
# name to a LazyCall spec, so experiment configs can merge several of these
# dicts and override individual fields before instantiation.

# Consistency-training schedule callback (CT-style stage ramp-up).
CTSchedule_CALLBACK = dict(
    ct_schedule=L(CTScheduleCallback)(q=2.0, ratio_limit=0.999, kimg_per_stage=12500),
)

# Single EMA with a constant decay (other fields are used by other EMA types).
EMA_CALLBACK = dict(
    ema=L(EMACallback)(type="constant", beta=0.9999, gamma=16.97, ema_halflife_kimg=500, ema_rampup_ratio=0.05),
)

# Several constant-decay EMAs tracked side by side under distinct names.
EMA_CONST_CALLBACKS = dict(
    ema_9999=L(EMACallback)(type="constant", beta=0.9999, ema_name="ema_9999"),
    ema_99995=L(EMACallback)(type="constant", beta=0.99995, ema_name="ema_99995"),
    ema_9996=L(EMACallback)(type="constant", beta=0.9996, ema_name="ema_9996"),
)

# Several power-function EMAs (parameterized by gamma) tracked side by side.
EMA_POWER_CALLBACKS = dict(
    ema_1=L(EMACallback)(type="power", gamma=96.99, ema_name="ema_1"),
    ema_5=L(EMACallback)(type="power", gamma=16.97, ema_name="ema_5"),
    ema_10=L(EMACallback)(type="power", gamma=6.94, ema_name="ema_10"),
)

ForcedWeightNorm_CALLBACK = dict(
    forced_weight_norm=L(ForcedWeightNormCallback)(),
)

# Clip gradients of the main "net" module to a max norm of 10.
GradClip_CALLBACK = dict(
    grad_clip=L(GradClipCallback)(grad_norm=10.0, model_key="net"),
)

GPUStats_CALLBACK = dict(
    gpu_stats=L(GPUStatsCallback)(every_n=100),
)

ParamCount_CALLBACK = dict(
    param_count=L(ParamCountCallback)(),
)

TrainProfiler_CALLBACK = dict(
    train_profiler=L(TrainProfilerCallback)(every_n=100),
)

# W&B logging; sample_logging_iter=None defers to trainer.logging_iter.
WANDB_CALLBACK = dict(
    wandb=L(WandbCallback)(sample_logging_iter=None),
)

MemTracker_CALLBACK = dict(
    mem_tracker=L(MemTrackerCallback)(),
)
|
FastGen/fastgen/configs/config.py
ADDED
|
@@ -0,0 +1,254 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, List, Optional, Dict
|
| 6 |
+
|
| 7 |
+
import copy
|
| 8 |
+
import attrs
|
| 9 |
+
from omegaconf import DictConfig
|
| 10 |
+
|
| 11 |
+
from fastgen.utils import LazyCall as L
|
| 12 |
+
from fastgen.configs.callbacks import WANDB_CALLBACK
|
| 13 |
+
from fastgen.configs.data import CIFAR10_Loader_Config
|
| 14 |
+
from fastgen.configs.net import EDM_CIFAR10_Config as EDMConfig
|
| 15 |
+
from fastgen.configs.opt import BaseOptimizerConfig, BaseSchedulerConfig
|
| 16 |
+
from fastgen.methods import FastGenModel
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@attrs.define(slots=False)
class CuDNNConfig:
    """cuDNN backend flags: reproducibility vs. speed trade-off."""

    # If set to True, cudnn will use deterministic cudnn functions for better reproducibility.
    deterministic: bool = False
    # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
    benchmark: bool = True
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
@attrs.define(slots=False)
class LogConfig:
    """Experiment logging / W&B configuration.

    Outputs are written under ``save_path``, derived from project/group/name.
    """

    # Project name
    project: str = "fastgen"
    # Experiment name
    group: str = "cifar10"
    # Run/job name
    name: str = "debug"
    # W&B mode, can be "online" or "disabled".
    wandb_mode: str = "online"
    # Wandb credential path
    wandb_credential: str = "./credentials/wandb_api.txt"

    # save path
    @property
    def save_path(self) -> str:
        # <FASTGEN_OUTPUT_ROOT>/<project>/<group>/<name>; the env var falls
        # back to the relative directory "FASTGEN_OUTPUT".
        return os.path.join(
            os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT"), f"{self.project}/{self.group}/{self.name}"
        )
|
| 46 |
+
|
| 47 |
+
|
| 48 |
+
@attrs.define(slots=False)
class EvalConfig:
    """Settings for offline sample generation / checkpoint evaluation."""

    # Number of samples to generate
    num_samples: int = 50000
    # Save a small batch of images
    save_images: bool = False
    # Minimum checkpoint to evaluate
    min_ckpt: int = 0
    # Maximum checkpoint to evaluate
    max_ckpt: int = 100000000
    # Directory to save samples
    samples_dir: str = "samples"
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
@attrs.define(slots=False)
class BaseCheckpointerConfig:
    """Checkpoint saving/loading settings (local dir plus optional S3)."""

    # Local directory (relative to the run's save path — TODO confirm) for checkpoints.
    save_dir: str = "checkpoints"
    # Mirror checkpoints to S3 when True.
    use_s3: bool = False
    s3_container: str = "s3://checkpoints/fastgen"
    s3_credential: str = "./credentials/s3.json"

    # path to pretrained model (from previous stages),
    # it's used by loading fsdp/ddp trained ckpt to an fsdp/ddp pipeline
    pretrained_ckpt_path: str = ""
    # submodule names of model and keys of a pretrained checkpoint of the form {"model": {"submodule_key": ...}, ...}
    # (attrs.define wraps this mutable dict default in a per-instance factory)
    pretrained_ckpt_key_map: Dict[str, str] = {"net": "net"}
|
| 74 |
+
|
| 75 |
+
|
| 76 |
+
@attrs.define(slots=False)
class SampleTConfig:
    """Config for sampling t from a time distribution.

    `time_dist_type` selects the distribution; the remaining fields are the
    parameters for the respective distributions (each field's comment states
    which distributions consume it).
    """

    # time distribution (currently supporting: uniform, lognormal, polynomial, logitnormal, shift, and log_t)
    time_dist_type: str = "uniform"
    # mu in lognormal, logitnormal, and log_t distributions
    train_p_mean: float = -1.1
    # sigma in lognormal, logitnormal, and log_t distributions
    train_p_std: float = 2.0
    # shift value in shifted sampling (t_shifted = t * shift / (t * (shift - 1) + 1))
    shift: float = 5.0
    # lowest value in truncated range
    min_t: float = 0.002
    # highest value in truncated range
    # (0.002–80 matches EDM-style sigma ranges — presumably; confirm against the noise scheduler)
    max_t: float = 80.0
    # If provided, it is in the form [t_max, ..., 0] where len(t_list) needs to equal student_sample_steps + 1
    t_list: Optional[List[float]] = None
    # degree of freedom in log-transformed student-t distribution
    log_t_df: float = 0.01
|
| 96 |
+
|
| 97 |
+
|
| 98 |
+
@attrs.define(slots=False)
class BaseModelConfig:
    """Model-level config: networks, optimizer/scheduler, t-sampling, precision, and
    distributed-training (DDP/FSDP) options for the student/teacher setup."""

    # Use factory functions to ensure each instance gets its own copy
    net: dict = attrs.field(factory=lambda: copy.deepcopy(EDMConfig))
    teacher: Optional[dict] = None  # Usually not used, only used when teacher is different from net (i.e. Causvid)

    # guidance scale for classifier-free guidance in teacher diffusion model. None means no guidance.
    guidance_scale: Optional[float] = None

    # enable skip layer guidance (currently only wan network has the skip_layers option in cfg)
    skip_layers: List[int] | None = None

    # optimizer and scheduler for the main net (i.e., one-step generator in DMD)
    net_optimizer: dict = attrs.field(factory=lambda: copy.deepcopy(BaseOptimizerConfig))
    net_scheduler: dict = attrs.field(factory=lambda: copy.deepcopy(BaseSchedulerConfig))

    # sampling t from a given distribution
    sample_t_cfg: SampleTConfig = attrs.field(factory=SampleTConfig)

    # shape of the input to the model (defaults to CIFAR-10)
    # (attrs' @define converts this mutable literal default into a per-instance factory)
    input_shape: List[int] = [3, 32, 32]
    # device ("cuda" or "cpu")
    device: str = "cuda"

    # enable gradient scaler
    grad_scaler_enabled: bool = False
    grad_scaler_init_scale: float = 65536.0
    grad_scaler_growth_interval: int = 2000

    # path to the pretrained teacher model ckpt
    pretrained_model_path: str = ""
    # path to the pretrained student net ckpt (if different from the teacher)
    pretrained_student_net_path: str = ""
    # initialize student from the above checkpoints (can be turned off to only load weights to the teacher)
    load_student_weights: bool = True

    # enable preprocessors in the model
    enable_preprocessors: bool = True

    # EMA for the main net (requires EMACallback)
    use_ema: Any = False

    # multistep generation if larger than 1 (default: single-step generation)
    student_sample_steps: int = 1
    # sampling type in multistep generation ('sde', 'ode')
    student_sample_type: str = "sde"

    # Enable memory-efficient model loading with meta device:
    # - Rank 0 loads pretrained weights normally
    # - Other ranks use torch.device("meta") for ZERO memory allocation (just metadata)
    # - FSDP materializes meta tensors and broadcasts weights from rank 0
    # This dramatically speeds up initialization for large models (14B+):
    # - Reduces RAM from N*model_size to 1*model_size
    # - Eliminates disk I/O contention (N parallel reads -> 1 read)
    # - Expected speedup: 30+ min -> <1 min for 14B models on 8 GPUs
    fsdp_meta_init: bool = False

    # whether to add the teacher model to the fsdp_dict
    add_teacher_to_fsdp_dict: bool = True

    # whether to find unused parameters in ddp
    # - can be turned off for improved performance
    # - however, it is required if the model has a discriminator or the net initializes unused modules (e.g., for logvar predictions)
    ddp_find_unused_parameters: bool = True

    # precision variables (choose from "float64", "float32", "bfloat16", or "float16")
    # (precision of the time steps is handled in the noise scheduler, defaulting to float64 for numerical stability)

    # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None
    precision: str = "float32"
    # AMP during training - if None or equal to precision, AMP is disabled during training.
    precision_amp: str | None = None
    # AMP during inference - if None or equal to precision, AMP is disabled during inference.
    precision_amp_infer: str | None = None
    # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding.
    precision_amp_enc: str | None = None
|
| 174 |
+
|
| 175 |
+
|
| 176 |
+
@attrs.define(slots=False)
class BaseTrainerConfig:
    """Trainer config: checkpointing, callbacks, logging/validation schedules,
    parallelism (DDP/FSDP), and batching."""

    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
    checkpointer: BaseCheckpointerConfig = attrs.field(factory=BaseCheckpointerConfig)

    # Callbacks configs.
    # Built through a factory so every config instance owns its own DictConfig.
    # A bare `DictConfig(WANDB_CALLBACK)` default is a single shared object
    # (attrs only auto-factories plain list/dict/set literals), so in-place
    # edits on one instance would leak into all others.
    callbacks: dict = attrs.field(factory=lambda: DictConfig(WANDB_CALLBACK))

    # save checkpoint frequency
    save_ckpt_iter: int = 5000
    # test on validation set frequency
    validation_iter: int = 1000
    # logging frequency
    logging_iter: int = 1000
    # maximum training iteration
    max_iter: int = 1000000
    # whether to visualize multistep teacher generation
    visualize_teacher: bool = False

    # Set the random seed.
    seed: int = 0
    # Validation seed
    val_seed: int | None = None
    # Resume
    resume: bool = True

    # DDP Parallelism
    ddp: bool = False
    # FSDP Parallelism
    fsdp: bool = False
    # Enable TensorFloat32 (convolution and matmul)
    tf32_enabled: bool = True

    # Number of gradient accumulation rounds
    grad_accum_rounds: int = 1

    # Global batch size (if not None, overrides grad_accum_rounds to match the specified batch size)
    batch_size_global: int | None = None

    # offload other modules to cpu during latent decoding
    offload_module_in_decoding: bool = False

    # apply cpu offloading in fsdp
    fsdp_cpu_offload: bool = False
    # Fallback minimum number of parameters for FSDP wrapping
    # (10M wraps large models into fairly small shards)
    # The FastGenNetwork should provide a fully_shard method that can be used to shard the network.
    # If we need to shard a different module, we fall back to an auto-sharding policy based on this value.
    fsdp_min_num_params: int = 10_000_000
    # Sharding group size for FSDP. If None, fully shard across all ranks.
    # If set, creates a 2D mesh with (replicate, shard) dimensions.
    fsdp_sharding_group_size: Optional[int] = None

    # global variables
    global_vars: Optional[dict] = None
    # (attrs' @define converts this mutable literal default into a per-instance factory)
    global_vars_val: List[dict | None] = [None]

    # augment config
    augment_pipe: Optional[DictConfig] = None
|
| 235 |
+
|
| 236 |
+
|
| 237 |
+
@attrs.define(slots=False)
class BaseConfig:
    """Top-level experiment config aggregating logging, trainer, model, data, and eval configs."""

    # Log config.
    log_config: LogConfig = attrs.field(factory=LogConfig)

    # Trainer configs.
    trainer: BaseTrainerConfig = attrs.field(factory=BaseTrainerConfig)

    # Model configs.
    model: BaseModelConfig = attrs.field(factory=BaseModelConfig)
    # NOTE(review): this default is a single DictConfig instance shared by every
    # BaseConfig (no factory) — confirm no code mutates it in place, or wrap in
    # attrs.field(factory=...) like the fields above.
    model_class: DictConfig = L(FastGenModel)(config=None)

    # Data configs.
    # NOTE(review): same shared-default caveat as model_class — experiments that
    # reassign dataloader_train are fine; in-place mutation would leak.
    dataloader_train: dict = CIFAR10_Loader_Config
    dataloader_val: Any = None

    # Eval configs.
    eval: EvalConfig = attrs.field(factory=EvalConfig)
|
FastGen/fastgen/configs/config_utils.py
ADDED
|
@@ -0,0 +1,317 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
from typing import Any, Optional, List, Dict
|
| 6 |
+
import inspect
|
| 7 |
+
from copy import deepcopy
|
| 8 |
+
|
| 9 |
+
import attrs
|
| 10 |
+
import yaml
|
| 11 |
+
from omegaconf import DictConfig, OmegaConf, ListConfig
|
| 12 |
+
from hydra import compose, initialize
|
| 13 |
+
from hydra.core.config_store import ConfigStore
|
| 14 |
+
|
| 15 |
+
import importlib
|
| 16 |
+
from dataclasses import fields as dataclass_fields
|
| 17 |
+
import attr
|
| 18 |
+
from dataclasses import is_dataclass
|
| 19 |
+
import fastgen.utils.logging_utils as logger
|
| 20 |
+
|
| 21 |
+
|
| 22 |
+
def import_config_from_python_file(config_file: str) -> Any:
    """
    Import a config from a python file.

    The file is imported as a module and must expose a module-level
    ``create_config()`` function whose result is returned.

    Args:
        config_file (str): The path to the python file (must end in ".py"
            and be importable, i.e. reachable from sys.path when its
            separators are converted to dots).

    Returns:
        Any: The config object returned by the module's ``create_config()``.

    Raises:
        ValueError: If the path does not end in ".py".
        FileNotFoundError: If the file does not exist.
        ImportError: If the module cannot be imported.
    """

    if not config_file.endswith(".py"):
        raise ValueError("Config file must be a Python file with a .py extension. " f"Received: {config_file}")

    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"FastGen config file ({config_file}) not found.")

    # Convert to importable module format.
    # Strip only the trailing ".py" suffix: a blanket str.replace(".py", "")
    # would also mangle ".py" occurrences inside directory names
    # (e.g. "my.python.tools/config.py").
    config_module = config_file.removesuffix(".py").replace("/", ".")

    # Import the module
    try:
        config = importlib.import_module(config_module)
    except ImportError as e:
        logger.error(f"Failed to import config from python file: {e}")
        raise e

    return config.create_config()
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
    """
    Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data

    Args:
        ref_instance: The reference instance to determine the type and fields when needed
        kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data

    Returns:
        Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data

    Raises:
        AssertionError: If the fields do not match or if extra keys are found.
        Exception: If there is an error constructing the new instance.
    """
    # Guard clause: primitive / unstructured data is passed through untouched.
    if not is_attrs_or_dataclass(ref_instance):
        return kwargs

    assert isinstance(kwargs, (dict, DictConfig)), "kwargs must be a dictionary or a DictConfig"
    ref_fields = set(get_fields(ref_instance))
    keys = set(kwargs.keys())
    extra_keys = keys - ref_fields
    # Every provided key must be a declared field of the reference type.
    assert (
        ref_fields == keys or keys.issubset(ref_fields)
    ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"

    # Recurse field-by-field, using the reference's current value as the
    # template for nested structured types.
    resolved_kwargs: Dict[str, Any] = {
        name: config_from_dict(getattr(ref_instance, name), kwargs[name]) for name in keys
    }
    try:
        return type(ref_instance)(**resolved_kwargs)
    except Exception as e:
        logger.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
        logger.error(e)
        raise e
|
| 91 |
+
|
| 92 |
+
|
| 93 |
+
def flatten_dict(nested_dict, parent_key="", sep=".", exclude_key="_target_"):
    """
    Flattens a nested dictionary by joining keys with a separator.

    Args:
        nested_dict (dict): The dictionary to flatten.
        parent_key (str, optional): The base key to prepend to flattened keys.
            Used in recursion. Defaults to ''.
        sep (str, optional): The separator to use between keys. Defaults to '.'.
        exclude_key (str, optional): Key skipped at every nesting level.
            Defaults to '_target_'.

    Returns:
        dict: A new, flattened dictionary. Entries whose value is None are dropped.
    """
    flat = {}
    for key, value in nested_dict.items():
        # Skip the excluded key and None-valued entries entirely.
        if key == exclude_key or value is None:
            continue

        # Join parent and child key unless we are at the top level.
        joined_key = f"{parent_key}{sep}{key}" if parent_key else key

        # Normalize OmegaConf containers to plain dicts before recursing.
        if isinstance(value, DictConfig):
            value = OmegaConf.to_container(value)
        if isinstance(value, dict) and value:
            # Non-empty dict: recurse and merge the flattened children.
            flat.update(flatten_dict(value, joined_key, sep=sep, exclude_key=exclude_key))
        else:
            # Leaf node (or empty dict): record it as-is.
            flat[joined_key] = value

    return flat
|
| 126 |
+
|
| 127 |
+
|
| 128 |
+
def override_config_with_opts(config: Any, opts: Optional[List[str]] = None) -> Any:
    """
    Override the config with the opts, using Hydra's override grammar.

    Args:
        config (Any): The config object (an attrs config or a DictConfig).
        opts (Optional[List[str]]): Hydra-style override strings
            ("key=value", "key.subkey=value", ...). The FIRST element must be
            the literal sentinel "-" separating the overrides from other CLI
            arguments; it is stripped before composition.

    Returns:
        Any: The config object with overrides applied (rebuilt via config_from_dict).
    """

    # Convert Config object to a DictConfig object for Hydra
    # (allow_objects keeps non-primitive values intact inside the DictConfig).
    if not isinstance(config, DictConfig):
        config_dict = attrs.asdict(config)
        config_dict = DictConfig(content=config_dict, flags={"allow_objects": True})
    else:
        config_dict = config

    # Use Hydra to handle overrides
    # NOTE: ConfigStore is process-global; storing under the fixed name
    # "config" replaces any previously stored node of that name.
    cs = ConfigStore.instance()
    cs.store(name="config", node=config_dict)

    if opts is None:
        opts = []

    # Enforce and strip the leading "-" sentinel.
    if len(opts) > 0 and opts[0] != "-":
        raise ValueError(f"opts must start with '-' to separate from other arguments. Got: {opts}")
    opts = opts[1:]
    with initialize(version_base=None):
        try:
            cfg = compose(config_name="config", overrides=opts)
        except Exception as e:
            raise ValueError(f"Failed to compose config with opts: {e}")

    # Resolve interpolations (e.g. ${...}) in place before reconstruction.
    OmegaConf.resolve(cfg)

    # Rebuild a typed config of the original type from the composed DictConfig.
    config = config_from_dict(config, cfg)

    return config
|
| 169 |
+
|
| 170 |
+
|
| 171 |
+
def override_config_with_yaml(config: Any, yaml_path: str) -> Any:
    """
    Override the config with the yaml file. **_target_ field is excluded**.

    Keys present in the YAML but absent from the config are silently ignored
    (loose overriding); matching keys are applied through the Hydra override
    path used by override_config_with_opts.
    """
    with open(yaml_path, "r") as f:
        loaded = yaml.safe_load(f)

    flat_yaml = flatten_dict(loaded)
    flat_config = flatten_dict(attrs.asdict(config))

    # Loose overriding: all mismatched keys are ignored
    applicable = {key: val for key, val in flat_yaml.items() if key in flat_config}

    # Prepend the "-" sentinel expected by override_config_with_opts.
    override_opts = ["-"]
    override_opts.extend(f"{key}={val}" for key, val in applicable.items())
    return override_config_with_opts(config, opts=override_opts)
|
| 185 |
+
|
| 186 |
+
|
| 187 |
+
def is_attrs_or_dataclass(obj) -> bool:
    """
    Check if the object is an instance of an attrs class or a dataclass.

    Args:
        obj: The object to check.

    Returns:
        bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
    """
    # Dataclass check first (cheap, stdlib); fall back to the attrs registry.
    if is_dataclass(obj):
        return True
    return attr.has(type(obj))
|
| 198 |
+
|
| 199 |
+
|
| 200 |
+
def get_fields(obj):
    """
    Get the fields of an attrs class or a dataclass.

    Args:
        obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.

    Returns:
        list: A list of field names.

    Raises:
        ValueError: If the object is neither an attrs class nor a dataclass.
    """
    # Try the stdlib dataclass protocol first, then the attrs registry.
    if is_dataclass(obj):
        return [f.name for f in dataclass_fields(obj)]
    if attr.has(type(obj)):
        return [f.name for f in attr.fields(type(obj))]
    raise ValueError("The object is neither an attrs class nor a dataclass.")
|
| 219 |
+
|
| 220 |
+
|
| 221 |
+
def serialize_config(
    config: Any,
    return_type: str = "dict",
    path: str | bytes | None = None,
    filename: str = "config.yaml",
    include_defaults: bool = False,
) -> Dict[str, Any] | str:
    """
    Serialize a config (BaseConfig or DictConfig) to various formats.

    Args:
        config: The config to serialize (BaseConfig attrs object or DictConfig).
        return_type: Output format - "dict" (plain dict), "yaml" (YAML string), or "file" (save to file).
        path: Directory path to save the file. Required if return_type is "file".
        filename: Name of the file to save. Only used if return_type is "file".
        include_defaults: If True, add default parameter values from _target_ classes.

    Returns:
        Dict[str, Any] if return_type is "dict"
        str (YAML) if return_type is "yaml" or "file"

    Raises:
        ValueError: If return_type is "file" but path is not provided.
    """
    if return_type == "file" and path is None:
        raise ValueError("path must be provided when return_type is 'file'")

    # Deep copy to avoid modifying original
    config = deepcopy(config)

    # Normalize to DictConfig with object support
    if not isinstance(config, DictConfig):
        config_dict = attrs.asdict(config)
        config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
    else:
        config_omegaconf = config

    def is_serializable(item) -> bool:
        # An item counts as serializable iff OmegaConf can dump it to YAML.
        try:
            OmegaConf.to_yaml(item)
            return True
        except Exception:
            return False

    def get_default_params(cls_or_func):
        # Collect keyword defaults from a callable's signature
        # (or from __init__ for non-callable objects).
        if callable(cls_or_func):
            signature = inspect.signature(cls_or_func)
        else:
            signature = inspect.signature(cls_or_func.__init__)
        params = signature.parameters
        return {name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty}

    def process_config(conf):
        # Recursively stringify non-serializable leaves in place and,
        # when include_defaults is set, inject _target_ signature defaults.
        if isinstance(conf, DictConfig):
            for key, value in conf.items():
                if isinstance(value, (DictConfig, ListConfig)):
                    # Optionally add default params from _target_ classes
                    if include_defaults:
                        try:
                            if "_target_" in value:
                                default_params = get_default_params(value["_target_"])
                                for default_key, default_v in default_params.items():
                                    if default_key not in value:
                                        value[default_key] = default_v
                        except Exception as e:
                            logger.error(f"Failed to add default argument values: {e}")
                    process_config(value)
                else:
                    if not is_serializable(value) and value is not None:
                        conf[key] = str(value)
        elif isinstance(conf, ListConfig):
            for i, item in enumerate(conf):
                if isinstance(item, (DictConfig, ListConfig)):
                    process_config(item)
                else:
                    if not is_serializable(item) and item is not None:
                        conf[i] = str(item)
        else:
            raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
        return conf

    config_omegaconf = process_config(config_omegaconf)
    result_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)  # type: ignore

    if return_type == "dict":
        return result_dict

    # For yaml and file, convert to YAML string
    yaml_str = yaml.dump(result_dict, default_flow_style=False, sort_keys=True)

    if return_type == "file":
        os.makedirs(path, exist_ok=True)  # type: ignore
        # Bug fix: honor the `filename` argument — it was previously ignored
        # and a hard-coded name was written instead.
        target_path = os.path.join(path, filename)  # type: ignore[arg-type]
        with open(target_path, "w") as f:
            f.write(yaml_str)
        logger.info(f"Config is saved at {target_path}")

    return yaml_str
|
FastGen/fastgen/configs/data.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import os
|
| 5 |
+
|
| 6 |
+
from fastgen.datasets.class_cond_dataloader import ImageLoader
|
| 7 |
+
from fastgen.datasets.wds_dataloaders import (
|
| 8 |
+
WDSLoader,
|
| 9 |
+
ImageWDSLoader,
|
| 10 |
+
VideoWDSLoader,
|
| 11 |
+
)
|
| 12 |
+
|
| 13 |
+
from fastgen.utils import LazyCall as L
|
| 14 |
+
|
| 15 |
+
# Root for all FastGen outputs; DATA_ROOT_DIR defaults to a DATA subfolder of it.
OUTPUT_ROOT = os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT")
DATA_ROOT_DIR = os.getenv("DATA_ROOT_DIR", f"{OUTPUT_ROOT}/DATA")
# Bug fix: this previously read the DATA_ROOT_DIR env var, so configuring a
# local data root would silently replace the S3 root with a local path.
S3_DATA_ROOT_DIR = os.getenv("S3_DATA_ROOT_DIR", "s3://data")
|
| 18 |
+
|
| 19 |
+
# ################################################################################
|
| 20 |
+
# Generic Loaders (for config templates - override datatags for actual use)
|
| 21 |
+
# ################################################################################
|
| 22 |
+
# See fastgen/datasets/README.md for more details.
|
| 23 |
+
|
| 24 |
+
ImageLoaderConfig = L(ImageWDSLoader)(
    # Placeholder shard tag — override with a real WebDataset path before use.
    datatags=["WDS:/path/to/images"],
    batch_size=32,
    # Maps batch keys to WebDataset entry extensions.
    key_map={"real": "jpg", "condition": "txt"},
    presets_map={"neg_condition": "empty_string"},
    input_res=512,
)

ImageLatentLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/image_latents"],
    batch_size=32,
    key_map={"real": "latent.pth", "condition": "txt_emb.pth"},
    # Negative condition embedding loaded from a shared file (same for all samples)
    files_map={"neg_condition": "/path/to/neg_prompt_emb.npy"},
)

VideoLoaderConfig = L(VideoWDSLoader)(
    datatags=["WDS:/path/to/videos"],
    batch_size=2,
    key_map={"real": "mp4", "condition": "txt"},
    presets_map={"neg_condition": "neg_prompt_wan"},
    sequence_length=81,
    img_size=(832, 480),
)

VideoLatentLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/video_latents"],
    batch_size=2,
    key_map={"real": "latent.pth", "condition": "txt_emb.pth"},
    # Negative condition embedding loaded from a shared file (same for all samples)
    files_map={"neg_condition": "/path/to/neg_prompt_emb.npy"},
    # NOTE: For v2v tasks, add condition latent (e.g., depth) to key_map:
    # key_map={"real": "latent.pth", "condition": "txt_emb.pth", "depth_latent": "depth_latent.pth"}
)

# ################################################################################
# Generic KD Loaders (for paired/path data)
# ################################################################################
# See fastgen/methods/knowledge_distillation/README.md for more details.

# For single-step KD: provides (real, noise, condition) pairs
# Data requirements: {"real": clean, "noise": noise, "condition": cond}
PairLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/pairs"],
    batch_size=2,
    key_map={"real": "latent.pth", "noise": "noise.pth", "condition": "txt_emb.pth"},
)

# For multi-step KD: provides (real, path, condition) with denoising trajectory
# Data requirements: {"real": clean, "path": [B, steps, C, ...], "condition": cond}
# path contains intermediate denoising steps (typically 4 steps)
PathLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/paths"],
    batch_size=2,
    key_map={"real": "latent.pth", "path": "path.pth", "condition": "txt_emb.pth"},
)

# ################################################################################
# Specific Datasets
# ################################################################################

CIFAR10_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/cifar10/cifar10-32x32.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/cifar10/cifar10-32x32.zip",
    use_labels=True,
    cache=True,
    batch_size=128,
    shuffle=True,
    sampler_start_idx=None,
)

ImageNet64_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-64/imagenet-64x64.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-64/imagenet-64x64.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)

ImageNet256_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-256/imagenet_256_sd.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-256/imagenet_256_sd.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)

ImageNet64_EDMV2_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-64/imagenet-64x64-edmv2.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-64/imagenet-64x64-edmv2.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)
|
FastGen/fastgen/configs/data_dummy.py
ADDED
|
@@ -0,0 +1,67 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Dummy data loader for data-free distillation training (e.g., Self-Forcing without GAN)."""
|
| 5 |
+
|
| 6 |
+
import torch
|
| 7 |
+
from fastgen.utils import LazyCall as L
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
class DummyVideoLoader:
    """An infinite dataloader that yields random tensors for data-free training.

    Use this when running Self-Forcing or other distillation methods without
    real data. Requires setting gan_loss_weight_gen=0 since there's no real
    data for the discriminator.

    Each batch is a dict with three entries:
      * "real":          random latents of shape ``[batch_size, *input_shape]``
      * "condition":     random text embeddings of shape
                         ``[batch_size, text_seq_len, text_dim]``
      * "neg_condition": all-zero text embeddings of the same shape as
                         "condition"
    """

    def __init__(
        self,
        batch_size: int = 1,
        input_shape=None,  # [C, T, H, W]; defaults to (16, 21, 60, 104)
        text_seq_len: int = 512,
        text_dim: int = 4096,
        device: str = "cuda",
        dtype: str = "bfloat16",
    ):
        self.batch_size = batch_size
        # Use a None sentinel instead of a mutable list default argument
        # (a shared list default is a classic Python pitfall).
        self.input_shape = (16, 21, 60, 104) if input_shape is None else input_shape
        self.text_seq_len = text_seq_len
        self.text_dim = text_dim
        self.device = device
        # Accept either a torch.dtype or its string name (e.g. "bfloat16").
        self._dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype

    def __iter__(self):
        # The loader is its own iterator; __next__ never raises StopIteration,
        # so iteration is unbounded (suitable for data-free training loops).
        return self

    def __next__(self):
        """Return one synthetic batch of latents and text embeddings."""
        return {
            "real": torch.randn(
                self.batch_size, *self.input_shape,
                device=self.device, dtype=self._dtype
            ),
            "condition": torch.randn(
                self.batch_size, self.text_seq_len, self.text_dim,
                device=self.device, dtype=self._dtype
            ),
            "neg_condition": torch.zeros(
                self.batch_size, self.text_seq_len, self.text_dim,
                device=self.device, dtype=self._dtype
            ),
        }
|
| 51 |
+
|
| 52 |
+
|
| 53 |
+
# Config for WAN T2V 480p (832x480, 81 frames -> 21 latent frames)
# Built with LazyCall (L), so the loader is presumably constructed later by
# the config system; device/dtype fall back to the class defaults
# ("cuda" / "bfloat16").
DummyVideoLoaderConfig = L(DummyVideoLoader)(
    batch_size=1,
    input_shape=[16, 21, 60, 104],  # [C, T, H, W] latent shape for 480p
    text_seq_len=512,
    text_dim=4096,
)

# Config for WAN T2V 720p (1280x720, 81 frames -> 21 latent frames)
DummyVideoLoader720pConfig = L(DummyVideoLoader)(
    batch_size=1,
    input_shape=[16, 21, 90, 160],  # [C, T, H, W] latent shape for 720p
    text_seq_len=512,
    text_dim=4096,
)
|
FastGen/fastgen/configs/discriminator.py
ADDED
|
@@ -0,0 +1,106 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
from fastgen.utils import LazyCall as L
|
| 7 |
+
from fastgen.networks.discriminators import (
|
| 8 |
+
Discriminator_EDM,
|
| 9 |
+
Discriminator_SD15,
|
| 10 |
+
Discriminator_SDXL,
|
| 11 |
+
Discriminator_ImageDiT,
|
| 12 |
+
Discriminator_VideoDiT,
|
| 13 |
+
)
|
| 14 |
+
|
| 15 |
+
# ---- EDM (UNet) backbone discriminators -----------------------------------
# feature_indices selects which backbone feature levels feed the discriminator;
# None presumably means "use all levels" — TODO confirm against the
# Discriminator_* implementations.
Discriminator_EDM_CIFAR10_Config: DictConfig = L(Discriminator_EDM)(
    feature_indices={0, 1, 2},
    all_res=[32, 16, 8],  # feature-map resolutions for 32x32 inputs
    in_channels=256,
)

Discriminator_EDM_ImageNet64_Config: DictConfig = L(Discriminator_EDM)(
    feature_indices=None,
    all_res=[64, 32, 16, 8],  # feature-map resolutions for 64x64 inputs
    in_channels=768,
)

# ---- Stable Diffusion UNet backbone discriminators -------------------------
Discriminator_SD15_Res512_Config: DictConfig = L(Discriminator_SD15)(
    feature_indices=None,
    all_res=[32, 16, 8, 8, 8],
    in_channels=1280,
)

Discriminator_SDXL_Res512_Config: DictConfig = L(Discriminator_SDXL)(
    feature_indices=None,
    all_res=[32, 16, 16, 16],
    in_channels=1280,
)

Discriminator_SDXL_Res1024_Config: DictConfig = L(Discriminator_SDXL)(
    feature_indices=None,
    all_res=[64, 32, 32, 32],
    in_channels=1280,
)

# ---- Image DiT backbone discriminators -------------------------------------
# Flux: hidden_dim=3072, 19 joint blocks + 38 single blocks = 57 total
Discriminator_Flux_Config: DictConfig = L(Discriminator_ImageDiT)(
    feature_indices=None,
    num_blocks=57,  # 19 joint + 38 single blocks
    inner_dim=3072,  # Flux hidden dimension
)

# ---- Video DiT backbone discriminators --------------------------------------
# Note the pattern: CogVideoX/Wan configs set inner_dim to backbone_dim // 4,
# while the Cosmos configs keep the full backbone inner_dim (see inline notes).

# 2B patchify: spatial-2, temporal-1; inner_dim=1920; layer=30
Discriminator_CogVideoX2B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=30,
    disc_type="dit_simple_conv3d",
    inner_dim=1920 // 4,
)

# 5B patchify: spatial-2, temporal-1; inner_dim=3072; layer=42
Discriminator_CogVideoX5B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=42,
    disc_type="dit_simple_conv3d",
    inner_dim=3072 // 4,
)

# 1.3B patchify: spatial-2, temporal-1; inner_dim=1536; layer=30
Discriminator_Wan_1_3B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=30,
    disc_type="dit_simple_conv3d",
    inner_dim=1536 // 4,
)

# 14B patchify: spatial-2, temporal-1; inner_dim=5120; layer=40
Discriminator_Wan_14B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=40,
    disc_type="dit_simple_conv3d",
    inner_dim=5120 // 4,
)

# 5B patchify: spatial-2, temporal-1; inner_dim=3072; layer=30
Discriminator_Wan22_5B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=30,
    disc_type="dit_simple_conv3d",
    inner_dim=3072 // 4,
)

# Cosmos Predict2.5-2B: patchify spatial-2, temporal-1; inner_dim=2048; layer=28
Discriminator_CosmosPredict2_2B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=28,
    disc_type="dit_simple_conv3d",
    inner_dim=2048,  # Must match model's inner_dim for Cosmos
)

# Cosmos Predict2.5-14B: patchify spatial-2, temporal-1; inner_dim=5120; layer=36
Discriminator_CosmosPredict2_14B_Config: DictConfig = L(Discriminator_VideoDiT)(
    feature_indices=None,
    num_blocks=36,
    disc_type="dit_simple_conv3d",
    inner_dim=5120,  # Must match model's inner_dim for Cosmos
)
|
FastGen/fastgen/configs/experiments/CogVideoX/config_dmd2.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
from fastgen.configs.discriminator import Discriminator_CogVideoX2B_Config
|
| 6 |
+
import fastgen.configs.methods.config_dmd2 as config_dmd2_default
|
| 7 |
+
from fastgen.configs.data import VideoLatentLoaderConfig
|
| 8 |
+
from fastgen.configs.net import CogVideoXConfig
|
| 9 |
+
|
| 10 |
+
""" Configs for the DMD2 model on CogVideoX model. """
|
| 11 |
+
|
| 12 |
+
|
| 13 |
+
def create_config():
    """Return the default DMD2 config specialized for CogVideoX-2B."""
    config = config_dmd2_default.create_config()
    model = config.model

    # One shared learning rate for generator, discriminator and fake-score nets.
    model.net_optimizer.lr = 1e-5
    model.discriminator_optimizer.lr = 1e-5
    model.fake_score_optimizer.lr = 1e-5

    # CogVideoX latent shape [C, T, H, W].
    model.input_shape = [16, 13, 60, 90]
    model.discriminator = Discriminator_CogVideoX2B_Config
    model.discriminator.feature_indices = [15, 22, 29]  # late blocks of 30
    model.gan_loss_weight_gen = 0.03
    model.net = CogVideoXConfig
    model.guidance_scale = 6.0
    model.enable_preprocessors = False

    # Uniform timestep sampling over (almost) the full [0, 1] range.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "uniform"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    model.gan_use_same_t_noise = True
    model.fake_score_pred_type = "x0"
    model.student_sample_type = "ode"

    # Fixed 4-step student sampling schedule.
    model.student_sample_steps = 4
    t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0]

    # Data.
    config.dataloader_train = VideoLatentLoaderConfig
    config.dataloader_train.batch_size = 2

    # Training schedule.
    config.trainer.max_iter = 10000
    config.trainer.logging_iter = 100
    config.trainer.save_ckpt_iter = 500

    config.log_config.group = "CogVideoX_dmd2"
    return config
|
FastGen/fastgen/configs/experiments/CogVideoX/config_kd.py
ADDED
|
@@ -0,0 +1,30 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
|
| 5 |
+
import fastgen.configs.methods.config_kd as config_kd_default
|
| 6 |
+
from fastgen.configs.data import PairLoaderConfig
|
| 7 |
+
from fastgen.configs.net import CogVideoXConfig
|
| 8 |
+
|
| 9 |
+
""" Configs for the KD model on CogVideoX model. """
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_config():
    """Return the knowledge-distillation (KD) config specialized for CogVideoX."""
    config = config_kd_default.create_config()
    model = config.model
    trainer = config.trainer

    # Training schedule.
    trainer.max_iter = 6000
    trainer.logging_iter = 100
    trainer.save_ckpt_iter = 500

    # Student optimizer.
    model.net_optimizer.lr = 1e-4

    # Network: CogVideoX latent shape [C, T, H, W].
    model.input_shape = [16, 13, 60, 90]
    model.net = CogVideoXConfig
    model.enable_preprocessors = False
    model.precision = "bfloat16"

    # Data.
    config.dataloader_train = PairLoaderConfig
    config.dataloader_train.batch_size = 2

    config.log_config.group = "CogVideoX_kd"
    return config
|
FastGen/fastgen/configs/experiments/CogVideoX/config_sft.py
ADDED
|
@@ -0,0 +1,35 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import fastgen.configs.methods.config_sft as config_sft_default
|
| 5 |
+
from fastgen.configs.data import VideoLatentLoaderConfig
|
| 6 |
+
from fastgen.configs.net import CogVideoXConfig
|
| 7 |
+
|
| 8 |
+
""" Configs for the SFT model on CogVideoX model. """
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_config():
    """Return the SFT config specialized for CogVideoX-2B."""
    config = config_sft_default.create_config()
    model = config.model

    config.trainer.logging_iter = 500

    model.net_optimizer.lr = 5e-5

    model.guidance_scale = 6.0
    model.student_sample_steps = 50

    # Uniform timestep sampling over (almost) the full [0, 1] range.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "uniform"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    # CogVideoX latent shape [C, T, H, W] = [16, 13, 60, 90]:
    # 49 frames at 480x720 resolution after VAE encoding.
    model.input_shape = [16, 13, 60, 90]
    model.net = CogVideoXConfig
    model.enable_preprocessors = False  # Using precomputed latents

    config.dataloader_train = VideoLatentLoaderConfig
    config.dataloader_train.batch_size = 2

    config.log_config.group = "CogVideoX_sft"
    return config
|
FastGen/fastgen/configs/experiments/CogVideoX/config_sft_5b.py
ADDED
|
@@ -0,0 +1,40 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
import fastgen.configs.methods.config_sft as config_sft_default
|
| 5 |
+
from fastgen.configs.data import VideoLatentLoaderConfig
|
| 6 |
+
from fastgen.configs.net import CogVideoX5BConfig
|
| 7 |
+
|
| 8 |
+
""" Configs for the SFT model on CogVideoX-5B model. """
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_config():
    """Return the SFT config specialized for the larger CogVideoX-5B."""
    config = config_sft_default.create_config()
    model = config.model

    config.trainer.logging_iter = 500

    # Slightly lower LR than the 2B recipe, to suit the larger model.
    model.net_optimizer.lr = 2e-5

    model.guidance_scale = 6.0
    model.student_sample_steps = 50

    # Uniform timestep sampling over (almost) the full [0, 1] range.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "uniform"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    model.precision = "bfloat16"

    # CogVideoX latent shape [C, T, H, W] = [16, 13, 60, 90]:
    # 49 frames at 480x720 after VAE encoding. Same as the 2B variant —
    # both models share the same VAE.
    model.input_shape = [16, 13, 60, 90]
    model.net = CogVideoX5BConfig
    model.enable_preprocessors = False  # Using precomputed latents

    # Data: batch size reduced for the larger model.
    config.dataloader_train = VideoLatentLoaderConfig
    config.dataloader_train.batch_size = 1

    config.log_config.group = "CogVideoX5B_sft"
    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2.py
ADDED
|
@@ -0,0 +1,69 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""DMD2 config for Cosmos-Predict2.5-2B model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_dmd2 as config_dmd2_default
|
| 7 |
+
from fastgen.configs.data import VideoLoaderConfig
|
| 8 |
+
from fastgen.configs.discriminator import Discriminator_CosmosPredict2_2B_Config
|
| 9 |
+
from fastgen.configs.net import CosmosPredict2_2B_Config, CosmosPredict2_2B_Aggressive_Config, CKPT_ROOT_DIR
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
def create_config():
    """Return the DMD2 config specialized for Cosmos-Predict2.5-2B."""
    config = config_dmd2_default.create_config()
    model = config.model

    # Training schedule.
    config.trainer.max_iter = 10000
    config.trainer.logging_iter = 100
    config.trainer.save_ckpt_iter = 500

    # One shared learning rate for generator, discriminator and fake-score nets.
    model.net_optimizer.lr = 1e-5
    model.discriminator_optimizer.lr = 1e-5
    model.fake_score_optimizer.lr = 1e-5

    model.precision = "bfloat16"

    # Latent shape [C, T_latent, H_latent, W_latent]; Cosmos VAE uses 4x8x8
    # compression (time, height, width). Alternatives:
    #   [16, 24, 22, 40]  -> 256p (320x176 video, 93 frames)
    #   [16, 24, 88, 160] -> full 720p (1280x704 video, 93 frames)
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p, 93 frames

    # Networks.
    model.net = CosmosPredict2_2B_Config
    model.discriminator = Discriminator_CosmosPredict2_2B_Config
    model.discriminator.disc_type = "multiscale_down_mlp_large"
    model.discriminator.feature_indices = [13, 20, 27]
    # Teacher uses the AGGRESSIVE SAC variant for memory savings.
    model.teacher = CosmosPredict2_2B_Aggressive_Config

    # DMD2 hyper-parameters.
    model.gan_loss_weight_gen = 0.03
    model.gan_use_same_t_noise = True
    model.fake_score_pred_type = "x0"
    model.student_sample_type = "ode"
    model.guidance_scale = 3.0
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-2B/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt"

    # Timestep sampling plus a fixed 4-step student schedule.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "shifted"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999
    model.student_sample_steps = 4
    t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0]

    # Dataloader: pixel size (W, H) and frame count are derived from the
    # latent shape (8x spatial, 4x temporal compression).
    config.dataloader_train = VideoLoaderConfig
    config.dataloader_train.batch_size = 1
    _, t_lat, h_lat, w_lat = model.input_shape
    config.dataloader_train.img_size = (w_lat * 8, h_lat * 8)
    config.dataloader_train.sequence_length = (t_lat - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_dmd2"

    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_14b.py
ADDED
|
@@ -0,0 +1,37 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""DMD2 config for Cosmos-Predict2.5-14B model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.experiments.CosmosPredict2.config_dmd2 as config_dmd2_base
|
| 7 |
+
from fastgen.configs.discriminator import Discriminator_CosmosPredict2_14B_Config
|
| 8 |
+
from fastgen.configs.net import CosmosPredict2_14B_Config, CosmosPredict2_14B_Aggressive_Config, CKPT_ROOT_DIR
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_config():
    """Return the DMD2 config scaled up to Cosmos-Predict2.5-14B."""
    config = config_dmd2_base.create_config()
    model = config.model

    # Enable FSDP CPU offload for the larger 14B model.
    config.trainer.fsdp_cpu_offload = True

    # Latent shape [C, T_latent, H_latent, W_latent]; Cosmos VAE uses 4x8x8
    # compression (time, height, width). Alternatives:
    #   [16, 24, 22, 40]  -> 256p (320x176 video, 93 frames)
    #   [16, 24, 88, 160] -> full 720p (1280x704 video, 93 frames)
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p, 93 frames

    # Swap in the 14B network, discriminator and teacher.
    # NOTE(review): reassigning the discriminator replaces the 2B base
    # overrides (disc_type / feature_indices) with the 14B defaults —
    # confirm this is intended.
    model.net = CosmosPredict2_14B_Config
    model.discriminator = Discriminator_CosmosPredict2_14B_Config
    # Teacher uses the AGGRESSIVE SAC variant for memory savings.
    model.teacher = CosmosPredict2_14B_Aggressive_Config

    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-14B/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt"

    # Recompute dataloader geometry from the latent shape:
    # img_size is (W, H) in pixels; sequence_length is the frame count.
    _, t_lat, h_lat, w_lat = model.input_shape
    config.dataloader_train.img_size = (w_lat * 8, h_lat * 8)
    config.dataloader_train.sequence_length = (t_lat - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_14b_dmd2"

    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_v2w.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""DMD2 config for Cosmos-Predict2.5-2B video2world model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.experiments.CosmosPredict2.config_dmd2 as config_dmd2_base
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_config():
    """Return the DMD2 config for Cosmos-Predict2.5-2B in video2world mode."""
    config = config_dmd2_base.create_config()

    net = config.model.net
    # Switch the network to video2world, conditioning on a single frame.
    net.is_video2world = True
    net.num_conditioning_frames = 1

    config.log_config.group = "cosmos_predict2_dmd2_v2w"

    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft.py
ADDED
|
@@ -0,0 +1,48 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Configs for SFT on Cosmos-Predict2.5-2B model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_sft as config_sft_default
|
| 7 |
+
from fastgen.configs.data import VideoLoaderConfig
|
| 8 |
+
from fastgen.configs.net import CosmosPredict2_2B_Config, CKPT_ROOT_DIR
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_config():
    """Return the SFT config specialized for Cosmos-Predict2.5-2B."""
    config = config_sft_default.create_config()
    model = config.model

    # Training schedule.
    config.trainer.max_iter = 10000
    config.trainer.logging_iter = 100
    config.trainer.save_ckpt_iter = 500

    model.net_optimizer.lr = 1e-5

    # Uniform timestep sampling over (almost) the full [0, 1] range.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "uniform"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    model.precision = "bfloat16"

    # Latent shape [C, T_latent, H_latent, W_latent]; Cosmos VAE uses 4x8x8
    # compression (time, height, width). Alternatives:
    #   [16, 24, 24, 40]  -> 256p
    #   [16, 21, 60, 104] -> 480p, 81 frames
    #   [16, 24, 88, 160] -> full 720p
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p

    model.net = CosmosPredict2_2B_Config
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-2B/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt"
    model.guidance_scale = 3.0
    model.student_sample_steps = 35

    # Dataloader: pixel size (W, H) and frame count are derived from the
    # latent shape (8x spatial, 4x temporal compression).
    config.dataloader_train = VideoLoaderConfig
    config.dataloader_train.batch_size = 1
    _, t_lat, h_lat, w_lat = model.input_shape
    config.dataloader_train.img_size = (w_lat * 8, h_lat * 8)
    config.dataloader_train.sequence_length = (t_lat - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_sft"

    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_14b.py
ADDED
|
@@ -0,0 +1,19 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Configs for SFT on Cosmos-Predict2.5-14B model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.experiments.CosmosPredict2.config_sft as config_sft_base
|
| 7 |
+
from fastgen.configs.net import CosmosPredict2_14B_Config, CKPT_ROOT_DIR
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_config():
    """Return the SFT config scaled up to Cosmos-Predict2.5-14B."""
    config = config_sft_base.create_config()

    # Swap the 2B network for the 14B variant and its matching checkpoint.
    config.model.net = CosmosPredict2_14B_Config
    config.model.pretrained_model_path = (
        f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-14B/base/"
        "post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt"
    )

    config.log_config.group = "cosmos_predict2_14b_sft"

    return config
|
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_v2w.py
ADDED
|
@@ -0,0 +1,18 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
"""Configs for SFT on Cosmos-Predict2.5-2B video2world model."""
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.experiments.CosmosPredict2.config_sft as config_sft_base
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
def create_config():
    """Return the SFT config for Cosmos-Predict2.5-2B in video2world mode."""
    config = config_sft_base.create_config()

    net = config.model.net
    # Switch the network to video2world, conditioning on a single frame.
    net.is_video2world = True
    net.num_conditioning_frames = 1

    config.log_config.group = "cosmos_predict2_sft_v2w"

    return config
|
FastGen/fastgen/configs/experiments/DiT/config_mf_b.py
ADDED
|
@@ -0,0 +1,70 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_mean_flow as config_mf_default
|
| 7 |
+
from fastgen.configs.data import ImageNet256_Loader_Config
|
| 8 |
+
from fastgen.configs.net import DiT_IN256_B_Config
|
| 9 |
+
from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
|
| 10 |
+
|
| 11 |
+
""" Configs for the MeanFlow model, on DiT-XL and ImageNet-256 dataset. """
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_config():
    """Return the MeanFlow training config for DiT-B on ImageNet-256."""
    config = config_mf_default.create_config()
    model = config.model
    trainer = config.trainer

    # ---- model ----
    # [C, H, W] latents for 256x256 images (256 / 8 = 32).
    model.input_shape = [4, 32, 32]
    model.precision_amp = "bfloat16"
    model.precision_amp_jvp = "float32"  # keep the JVP pass in float32
    model.cond_dropout_prob = 0.1
    model.guidance_mixture_ratio = 0.5

    # Logit-normal timestep sampling.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "logitnormal"
    t_cfg.train_p_mean = -0.4
    t_cfg.train_p_std = 1.0
    t_cfg.min_t = 0.0
    t_cfg.max_t = 0.999
    t_cfg.r_sample_ratio = 0.25

    # Loss configuration.
    loss_cfg = model.loss_config
    loss_cfg.norm_method = "poly_1.0"
    loss_cfg.norm_const = 1.0
    loss_cfg.tangent_warmup_steps = 0
    loss_cfg.loss_type = "l2"

    model.net = DiT_IN256_B_Config
    model.net.scale_t = False  # remove the additional 1000 factor in JVP
    model.net.r_timestep = True
    model.net.time_cond_type = "diff"

    opt = model.net_optimizer
    opt.optim_type = "adam"
    opt.lr = 1e-4
    opt.betas = (0.9, 0.95)
    opt.weight_decay = 0.0

    # ---- EMA: drop inherited EMA callbacks, then install the constant-rate set.
    model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
    kept = {name: cb for name, cb in trainer.callbacks.items() if not name.startswith("ema")}
    trainer.callbacks = DictConfig(kept)
    trainer.callbacks.update(EMA_CONST_CALLBACKS)

    # Recommended setting for 2-step:
    # config.model.sample_t_cfg.t_list = [0.999, 0.5, 0.0]
    # config.model.student_sample_steps = 2

    # ---- dataloader ----
    config.dataloader_train = ImageNet256_Loader_Config
    config.dataloader_train.batch_size = 32

    # ---- trainer ----
    trainer.batch_size_global = 1024
    trainer.max_iter = 1200000
    trainer.save_ckpt_iter = 50000
    trainer.logging_iter = 10000

    config.log_config.group = "imagenet256"

    return config
|
FastGen/fastgen/configs/experiments/DiT/config_mf_xl.py
ADDED
|
@@ -0,0 +1,26 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from fastgen.configs.net import DiT_IN256_XL_Config
|
| 5 |
+
import fastgen.configs.experiments.DiT.config_mf_b as config_mf_default
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
""" Configs for the MeanFlow model, on DiT-XL and ImageNet-256 dataset. """
|
| 9 |
+
|
| 10 |
+
|
| 11 |
+
def create_config():
    """Return the MeanFlow training config scaled up to DiT-XL on ImageNet-256."""
    config = config_mf_default.create_config()
    model = config.model

    # Guidance settings for the XL variant.
    model.guidance_t_end = 0.75
    model.guidance_scale = 0.2
    model.guidance_mixture_ratio = 0.92

    model.net = DiT_IN256_XL_Config
    model.net.scale_t = False  # remove the additional 1000 factor in JVP
    model.net.r_timestep = True
    model.net.time_cond_type = "diff"

    # Smaller per-GPU batch for the larger model.
    config.dataloader_train.batch_size = 8

    return config
|
FastGen/fastgen/configs/experiments/DiT/config_sft_dit_xl.py
ADDED
|
@@ -0,0 +1,68 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_sft as config_sft_default
|
| 7 |
+
from fastgen.configs.data import ImageNet256_Loader_Config
|
| 8 |
+
from fastgen.configs.net import DiT_IN256_XL_Config, CKPT_ROOT_DIR
|
| 9 |
+
from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
|
| 10 |
+
|
| 11 |
+
"""Configs for SFT (Supervised Fine-Tuning) on DiT-XL and ImageNet-256 dataset."""
|
| 12 |
+
|
| 13 |
+
|
| 14 |
+
def create_config():
    """Build the SFT (supervised fine-tuning) config for DiT-XL on ImageNet-256."""
    cfg = config_sft_default.create_config()
    model = cfg.model
    trainer = cfg.trainer

    # --- model ---
    # DiT latent shape: [C, H, W] = [4, 32, 32] for 256x256 images (256/8 = 32)
    model.input_shape = [4, 32, 32]
    model.precision_amp = "bfloat16"
    model.cond_dropout_prob = 0.1  # 10% dropout for CFG training

    # --- timestep sampling ---
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "logitnormal"
    t_cfg.train_p_mean = -0.4
    t_cfg.train_p_std = 1.0
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    # --- pretrained weights ---
    # DiT-XL/2 ImageNet-256 checkpoint from https://github.com/facebookresearch/DiT
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-256/DiT-XL-2-256x256.pt"

    # --- network ---
    net = model.net = DiT_IN256_XL_Config
    net.learn_sigma = True  # Facebook DiT checkpoint was trained with learn_sigma=True
    # Facebook DiT was trained with DDPM (epsilon prediction), not flow matching.
    net.net_pred_type = "eps"
    net.schedule_type = "sd"  # same linear beta schedule as DiT (0.0001 to 0.02)

    # --- optimizer (lower LR for fine-tuning; 1e-4 is for training from scratch) ---
    opt = model.net_optimizer
    opt.optim_type = "adamw"
    opt.lr = 1e-5  # 10x lower for fine-tuning
    opt.betas = (0.9, 0.95)
    opt.weight_decay = 0.0

    # --- EMA: drop default ema_* callbacks, install the constant-rate ones ---
    model.use_ema = ["ema_9999", "ema_99995"]
    trainer.callbacks = DictConfig(
        {name: cb for name, cb in trainer.callbacks.items() if not name.startswith("ema")}
    )
    trainer.callbacks.update(EMA_CONST_CALLBACKS)

    # --- data ---
    cfg.dataloader_train = ImageNet256_Loader_Config

    # --- sampling for visualization ---
    model.student_sample_steps = 50  # DiT typically uses 50-250 steps

    # --- trainer ---
    trainer.batch_size_global = 256
    trainer.max_iter = 400000
    trainer.save_ckpt_iter = 10000
    trainer.logging_iter = 1000

    cfg.log_config.group = "dit_xl_imagenet256_sft"

    return cfg
|
FastGen/fastgen/configs/experiments/DiT/config_sft_sit_xl.py
ADDED
|
@@ -0,0 +1,75 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_sft as config_sft_default
|
| 7 |
+
from fastgen.configs.data import ImageNet256_Loader_Config
|
| 8 |
+
from fastgen.configs.net import DiT_IN256_XL_Config, CKPT_ROOT_DIR
|
| 9 |
+
from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
|
| 10 |
+
|
| 11 |
+
"""Configs for SFT (Supervised Fine-Tuning) on SiT-XL and ImageNet-256 dataset.
|
| 12 |
+
|
| 13 |
+
SiT (Scalable Interpolant Transformers) uses the same architecture as DiT but
|
| 14 |
+
is trained with flow matching (rectified flow) instead of DDPM.
|
| 15 |
+
|
| 16 |
+
Reference: https://github.com/willisma/SiT
|
| 17 |
+
"""
|
| 18 |
+
|
| 19 |
+
|
| 20 |
+
def create_config():
    """Build the SFT (supervised fine-tuning) config for SiT-XL on ImageNet-256.

    SiT shares the DiT architecture but was trained with flow matching
    (rectified flow) rather than DDPM.
    """
    cfg = config_sft_default.create_config()
    model = cfg.model
    trainer = cfg.trainer

    # --- model ---
    # SiT latent shape: [C, H, W] = [4, 32, 32] for 256x256 images (256/8 = 32)
    model.input_shape = [4, 32, 32]
    model.precision_amp = "bfloat16"
    model.cond_dropout_prob = 0.1  # 10% dropout for CFG training

    # --- timestep sampling for flow matching ---
    # SiT uses uniform time sampling during training.
    t_cfg = model.sample_t_cfg
    t_cfg.time_dist_type = "uniform"
    t_cfg.min_t = 0.001
    t_cfg.max_t = 0.999

    # --- pretrained weights ---
    # SiT-XL/2 ImageNet-256 checkpoint from https://github.com/willisma/SiT
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-256/SiT-XL-2-256x256.pt"

    # --- network: SiT reuses the DiT architecture ---
    net = model.net = DiT_IN256_XL_Config
    net.learn_sigma = True  # SiT checkpoint outputs 8 channels
    # SiT was trained with flow matching (rectified flow).
    net.net_pred_type = "flow"  # flow/velocity prediction
    net.schedule_type = "rf"  # use RF schedule
    net.use_sit_convention = True  # SiT convention: t -> 1-t, v -> -v
    net.scale_t = False  # SiT uses continuous time t in [0, 1]

    # --- optimizer (lower LR for fine-tuning; 1e-4 is for training from scratch) ---
    opt = model.net_optimizer
    opt.optim_type = "adamw"
    opt.lr = 1e-5  # 10x lower for fine-tuning
    opt.betas = (0.9, 0.95)
    opt.weight_decay = 0.0

    # --- EMA: drop default ema_* callbacks, install the constant-rate ones ---
    model.use_ema = ["ema_9999", "ema_99995"]
    trainer.callbacks = DictConfig(
        {name: cb for name, cb in trainer.callbacks.items() if not name.startswith("ema")}
    )
    trainer.callbacks.update(EMA_CONST_CALLBACKS)

    # --- data ---
    cfg.dataloader_train = ImageNet256_Loader_Config

    # --- sampling for visualization ---
    model.student_sample_steps = 50  # standard steps
    model.guidance_scale = 4.0  # standard CFG

    # --- trainer ---
    trainer.batch_size_global = 256
    trainer.max_iter = 400000
    trainer.save_ckpt_iter = 10000

    cfg.log_config.group = "sit_xl_imagenet256_sft"

    return cfg
|
FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
from fastgen.configs.methods.config_cm import create_config as cm_create_config
|
| 6 |
+
from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
|
| 7 |
+
from fastgen.configs.net import CKPT_ROOT_DIR
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_config():
    """Build the consistency-model (CM) config for EDM on unconditional CIFAR-10."""
    cfg = cm_create_config()

    # Recommended setting for 2-step sampling:
    #   cfg.model.sample_t_cfg.t_list = [80.0, 0.821, 0.0]
    #   cfg.model.student_sample_steps = 2

    cfg.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-uncond-vp.pth"

    # EMA: drop the default ema_* callbacks and install the constant-rate ones.
    cfg.model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
    cfg.trainer.callbacks = DictConfig(
        {name: cb for name, cb in cfg.trainer.callbacks.items() if not name.startswith("ema")}
    )
    cfg.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    cfg.trainer.max_iter = 350000
    cfg.trainer.batch_size_global = 512

    return cfg
|
FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10_fast.py
ADDED
|
@@ -0,0 +1,25 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from fastgen.configs.methods.config_cm import create_config as cm_create_config
|
| 5 |
+
from fastgen.configs.net import CKPT_ROOT_DIR
|
| 6 |
+
|
| 7 |
+
|
| 8 |
+
def create_config():
    """Build a fast (small-scale) CM config for EDM on unconditional CIFAR-10."""
    cfg = cm_create_config()

    # Recommended setting for 2-step sampling:
    #   cfg.model.sample_t_cfg.t_list = [80.0, 0.821, 0.0]
    #   cfg.model.student_sample_steps = 2

    cfg.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-uncond-vp.pth"

    # Compress the CT schedule and shrink the run for quick iteration.
    cfg.trainer.callbacks.ct_schedule.kimg_per_stage = 512000
    cfg.dataloader_train.batch_size = 128

    cfg.trainer.max_iter = 8000
    cfg.trainer.callbacks.ct_schedule.q = 256.0
    cfg.trainer.callbacks.ema.beta = 0.9993
    cfg.trainer.save_ckpt_iter = 500
    cfg.trainer.logging_iter = 100
    return cfg
|
FastGen/fastgen/configs/experiments/EDM/config_cm_in64.py
ADDED
|
@@ -0,0 +1,65 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
|
| 6 |
+
import fastgen.configs.methods.config_cm as config_cm_default
|
| 7 |
+
from fastgen.configs.data import ImageNet64_Loader_Config
|
| 8 |
+
from fastgen.configs.net import EDM_ImageNet64_Config, CKPT_ROOT_DIR
|
| 9 |
+
from fastgen.utils import LazyCall as L
|
| 10 |
+
from fastgen.utils.lr_scheduler import LambdaInverseSquareRootScheduler
|
| 11 |
+
from fastgen.configs.callbacks import EMA_POWER_CALLBACKS
|
| 12 |
+
|
| 13 |
+
""" Configs for the CM model, on EDM2 and ImageNet-64 dataset. """
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def create_config():
    """Build the CM config for EDM2 on class-conditional ImageNet-64."""
    cfg = config_cm_default.create_config()
    model = cfg.model
    trainer = cfg.trainer

    # --- trainer / CT schedule ---
    # Recommended setting for ImageNet-64 is max_iter * batch_size // (4 * 1000).
    trainer.callbacks.ct_schedule.kimg_per_stage = 3200
    trainer.callbacks.ct_schedule.q = 4
    trainer.callbacks.ct_schedule.ratio_limit = 0.9961

    # --- EMA: drop default ema_* callbacks, install the power-function ones ---
    model.use_ema = ["ema_1", "ema_5", "ema_10"]
    trainer.callbacks = DictConfig(
        {name: cb for name, cb in trainer.callbacks.items() if not name.startswith("ema")}
    )
    trainer.callbacks.update(EMA_POWER_CALLBACKS)

    # --- model / mixed precision ---
    model.precision_amp = "float16"
    model.grad_scaler_enabled = True
    model.grad_scaler_init_scale = 16
    model.grad_scaler_growth_interval = 20000
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-64/edm-imagenet-64x64-cond-adm.pth"
    model.input_shape = [3, 64, 64]
    model.sample_t_cfg.train_p_mean = -0.8
    model.sample_t_cfg.train_p_std = 1.6
    model.loss_config.huber_const = 0.06
    model.loss_config.weighting_ct_loss = "c_out_sq"

    # Recommended setting for 2-step sampling:
    #   model.sample_t_cfg.t_list = [80.0, 1.526, 0.0]
    #   model.student_sample_steps = 2

    # --- network & optimizer ---
    model.net = EDM_ImageNet64_Config
    model.net.dropout = 0.2
    model.net_optimizer.optim_type = "adam"
    model.net_optimizer.lr = 1e-3
    model.net_optimizer.betas = (0.9, 0.99)
    model.net_optimizer.weight_decay = 0.0
    model.net_scheduler = L(LambdaInverseSquareRootScheduler)(
        warm_up_steps=0,
        decay_steps=2000,
    )
    # During inference, sigma_shift can improve 2-step results:
    #   model.net.sigma_shift = 0.003

    # --- data ---
    cfg.dataloader_train = ImageNet64_Loader_Config

    cfg.log_config.group = "edm_imagenet64_cm"
    return cfg
|
FastGen/fastgen/configs/experiments/EDM/config_dmd2_cifar10.py
ADDED
|
@@ -0,0 +1,24 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
|
| 4 |
+
from omegaconf import DictConfig
|
| 5 |
+
from fastgen.configs.methods.config_dmd2 import create_config as dmd2_create_config
|
| 6 |
+
from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
|
| 7 |
+
from fastgen.configs.net import CKPT_ROOT_DIR
|
| 8 |
+
|
| 9 |
+
|
| 10 |
+
def create_config():
    """Build the DMD2 config for EDM on class-conditional CIFAR-10."""
    cfg = dmd2_create_config()

    cfg.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-cond-vp.pth"

    # EMA: drop the default ema_* callbacks and install the constant-rate ones.
    cfg.model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
    cfg.trainer.callbacks = DictConfig(
        {name: cb for name, cb in cfg.trainer.callbacks.items() if not name.startswith("ema")}
    )
    cfg.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    cfg.trainer.max_iter = 100000
    cfg.trainer.batch_size_global = 2048

    return cfg
|