taohu commited on
Commit
0839907
·
verified ·
1 Parent(s): f4f2bac

Upload folder using huggingface_hub

Browse files
This view is limited to 50 files because it contains too many changes.   See raw diff
Files changed (50) hide show
  1. .gitattributes +108 -0
  2. FastGen/.claude/settings.local.json +19 -0
  3. FastGen/.gitignore +180 -0
  4. FastGen/CLAUDE.md +7 -0
  5. FastGen/CONTRIBUTING.md +93 -0
  6. FastGen/Dockerfile +77 -0
  7. FastGen/LICENSE +202 -0
  8. FastGen/Makefile +29 -0
  9. FastGen/README.md +182 -0
  10. FastGen/SF.md +146 -0
  11. FastGen/assets/teaser.png +3 -0
  12. FastGen/fastgen/__init__.py +2 -0
  13. FastGen/fastgen/callbacks/README.md +59 -0
  14. FastGen/fastgen/callbacks/__init__.py +2 -0
  15. FastGen/fastgen/callbacks/callback.py +183 -0
  16. FastGen/fastgen/callbacks/ct_schedule.py +83 -0
  17. FastGen/fastgen/callbacks/ema.py +155 -0
  18. FastGen/fastgen/callbacks/forced_weight_norm.py +28 -0
  19. FastGen/fastgen/callbacks/gpu_mem_profiler.py +134 -0
  20. FastGen/fastgen/callbacks/gpu_stats.py +92 -0
  21. FastGen/fastgen/callbacks/grad_clip.py +222 -0
  22. FastGen/fastgen/callbacks/param_count.py +116 -0
  23. FastGen/fastgen/callbacks/train_profiler.py +138 -0
  24. FastGen/fastgen/callbacks/wandb.py +404 -0
  25. FastGen/fastgen/configs/README.md +108 -0
  26. FastGen/fastgen/configs/__init__.py +2 -0
  27. FastGen/fastgen/configs/callbacks.py +63 -0
  28. FastGen/fastgen/configs/config.py +254 -0
  29. FastGen/fastgen/configs/config_utils.py +317 -0
  30. FastGen/fastgen/configs/data.py +123 -0
  31. FastGen/fastgen/configs/data_dummy.py +67 -0
  32. FastGen/fastgen/configs/discriminator.py +106 -0
  33. FastGen/fastgen/configs/experiments/CogVideoX/config_dmd2.py +48 -0
  34. FastGen/fastgen/configs/experiments/CogVideoX/config_kd.py +30 -0
  35. FastGen/fastgen/configs/experiments/CogVideoX/config_sft.py +35 -0
  36. FastGen/fastgen/configs/experiments/CogVideoX/config_sft_5b.py +40 -0
  37. FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2.py +69 -0
  38. FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_14b.py +37 -0
  39. FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_v2w.py +18 -0
  40. FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft.py +48 -0
  41. FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_14b.py +19 -0
  42. FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_v2w.py +18 -0
  43. FastGen/fastgen/configs/experiments/DiT/config_mf_b.py +70 -0
  44. FastGen/fastgen/configs/experiments/DiT/config_mf_xl.py +26 -0
  45. FastGen/fastgen/configs/experiments/DiT/config_sft_dit_xl.py +68 -0
  46. FastGen/fastgen/configs/experiments/DiT/config_sft_sit_xl.py +75 -0
  47. FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10.py +28 -0
  48. FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10_fast.py +25 -0
  49. FastGen/fastgen/configs/experiments/EDM/config_cm_in64.py +65 -0
  50. FastGen/fastgen/configs/experiments/EDM/config_dmd2_cifar10.py +24 -0
.gitattributes CHANGED
@@ -33,3 +33,111 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
33
  *.zip filter=lfs diff=lfs merge=lfs -text
34
  *.zst filter=lfs diff=lfs merge=lfs -text
35
  *tfevents* filter=lfs diff=lfs merge=lfs -text
36
+ FastGen/assets/teaser.png filter=lfs diff=lfs merge=lfs -text
37
+ FastGen/scripts/inference/examples/00_child_swings_rusty_swing_set.mp4 filter=lfs diff=lfs merge=lfs -text
38
+ FastGen/scripts/inference/examples/00_child_swings_rusty_swing_set.png filter=lfs diff=lfs merge=lfs -text
39
+ FastGen/scripts/inference/examples/01_impressionist_rubber_duck_sunset.mp4 filter=lfs diff=lfs merge=lfs -text
40
+ FastGen/scripts/inference/examples/01_impressionist_rubber_duck_sunset.png filter=lfs diff=lfs merge=lfs -text
41
+ FastGen/scripts/inference/examples/02_two_cyclists_stop_sign.mp4 filter=lfs diff=lfs merge=lfs -text
42
+ FastGen/scripts/inference/examples/02_two_cyclists_stop_sign.png filter=lfs diff=lfs merge=lfs -text
43
+ FastGen/scripts/inference/examples/03_astronaut_cow_beach_van_gogh.mp4 filter=lfs diff=lfs merge=lfs -text
44
+ FastGen/scripts/inference/examples/03_astronaut_cow_beach_van_gogh.png filter=lfs diff=lfs merge=lfs -text
45
+ FastGen/scripts/inference/examples/04_bird_flying_church_tower_clock.mp4 filter=lfs diff=lfs merge=lfs -text
46
+ FastGen/scripts/inference/examples/04_bird_flying_church_tower_clock.png filter=lfs diff=lfs merge=lfs -text
47
+ checkpoints/Self-Forcing/vidprom_filtered_extended.txt filter=lfs diff=lfs merge=lfs -text
48
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/comp_effic.png filter=lfs diff=lfs merge=lfs -text
49
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/data_for_diff_stage.jpg filter=lfs diff=lfs merge=lfs -text
50
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/i2v_res.png filter=lfs diff=lfs merge=lfs -text
51
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/t2v_res.jpg filter=lfs diff=lfs merge=lfs -text
52
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/vben_1.3b_vs_sota.png filter=lfs diff=lfs merge=lfs -text
53
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/vben_vs_sota.png filter=lfs diff=lfs merge=lfs -text
54
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/video_dit_arch.jpg filter=lfs diff=lfs merge=lfs -text
55
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/assets/video_vae_res.jpg filter=lfs diff=lfs merge=lfs -text
56
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/examples/i2v_input.JPG filter=lfs diff=lfs merge=lfs -text
57
+ hf_models/Wan2.1-T2V-1.3B-Diffusers/tokenizer/tokenizer.json filter=lfs diff=lfs merge=lfs -text
58
+ pip_wheels/accelerate-1.12.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
59
+ pip_wheels/anyio-4.12.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
60
+ pip_wheels/av-14.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
61
+ pip_wheels/babel-2.18.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
62
+ pip_wheels/beautifulsoup4-4.14.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
63
+ pip_wheels/bleach-6.3.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
64
+ pip_wheels/bokeh-3.8.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
65
+ pip_wheels/boto3-1.42.39-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
66
+ pip_wheels/botocore-1.42.39-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
67
+ pip_wheels/certifi-2026.1.4-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
68
+ pip_wheels/cffi-2.0.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
69
+ pip_wheels/charset_normalizer-3.4.4-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
70
+ pip_wheels/click-8.3.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
71
+ pip_wheels/contourpy-1.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
72
+ pip_wheels/debugpy-1.8.20-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
73
+ pip_wheels/diffusers-0.35.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
74
+ pip_wheels/fsspec-2026.1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
75
+ pip_wheels/gitpython-3.1.46-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
76
+ pip_wheels/hf_xet-1.2.0-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
77
+ pip_wheels/huggingface_hub-0.36.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
78
+ pip_wheels/imageio-2.37.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
79
+ pip_wheels/imageio_ffmpeg-0.6.0-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
80
+ pip_wheels/ipykernel-7.1.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
81
+ pip_wheels/ipython-9.10.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
82
+ pip_wheels/ipywidgets-8.1.8-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
83
+ pip_wheels/jedi-0.19.2-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
84
+ pip_wheels/jinja2-3.1.6-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
85
+ pip_wheels/jupyter_client-8.8.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
86
+ pip_wheels/jupyter_server-2.17.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
87
+ pip_wheels/jupyterlab-4.5.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
88
+ pip_wheels/jupyterlab_widgets-3.0.16-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
89
+ pip_wheels/lark-1.3.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
90
+ pip_wheels/moviepy-2.2.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
91
+ pip_wheels/mpmath-1.3.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
92
+ pip_wheels/narwhals-2.16.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
93
+ pip_wheels/nbconvert-7.17.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
94
+ pip_wheels/networkx-3.6.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
95
+ pip_wheels/notebook-7.5.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
96
+ pip_wheels/numpy-1.26.4-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
97
+ pip_wheels/nvidia_cublas_cu12-12.4.5.8-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
98
+ pip_wheels/nvidia_cuda_cupti_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
99
+ pip_wheels/nvidia_cuda_nvrtc_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
100
+ pip_wheels/nvidia_cuda_runtime_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
101
+ pip_wheels/nvidia_cudnn_cu12-9.1.0.70-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
102
+ pip_wheels/nvidia_cufft_cu12-11.2.1.3-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
103
+ pip_wheels/nvidia_curand_cu12-10.3.5.147-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
104
+ pip_wheels/nvidia_cusolver_cu12-11.6.1.9-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
105
+ pip_wheels/nvidia_cusparse_cu12-12.3.1.170-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
106
+ pip_wheels/nvidia_cusparselt_cu12-0.6.2-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
107
+ pip_wheels/nvidia_nccl_cu12-2.21.5-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
108
+ pip_wheels/nvidia_nvjitlink_cu12-12.4.127-py3-none-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
109
+ pip_wheels/opencv_python_headless-4.11.0.86-cp37-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
110
+ pip_wheels/pandas-2.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
111
+ pip_wheels/parso-0.8.5-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
112
+ pip_wheels/pillow-11.3.0-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
113
+ pip_wheels/plotly-6.5.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
114
+ pip_wheels/prompt_toolkit-3.0.52-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
115
+ pip_wheels/protobuf-6.33.5-cp39-abi3-manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
116
+ pip_wheels/psutil-7.2.2-cp36-abi3-manylinux2010_x86_64.manylinux_2_12_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
117
+ pip_wheels/pydantic-2.12.5-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
118
+ pip_wheels/pydantic_core-2.41.5-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
119
+ pip_wheels/pygments-2.19.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
120
+ pip_wheels/python_dateutil-2.9.0.post0-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
121
+ pip_wheels/pytz-2025.2-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
122
+ pip_wheels/pyyaml-6.0.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
123
+ pip_wheels/pyzmq-26.4.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
124
+ pip_wheels/rdkit-2024.3.2-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
125
+ pip_wheels/regex-2026.1.15-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.manylinux_2_28_x86_64.whl filter=lfs diff=lfs merge=lfs -text
126
+ pip_wheels/rpds_py-0.30.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
127
+ pip_wheels/safetensors-0.7.0-cp38-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
128
+ pip_wheels/scipy-1.16.3-cp312-cp312-manylinux2014_x86_64.manylinux_2_17_x86_64.whl filter=lfs diff=lfs merge=lfs -text
129
+ pip_wheels/sentencepiece-0.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
130
+ pip_wheels/sentry_sdk-2.51.0-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
131
+ pip_wheels/setuptools-80.10.2-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
132
+ pip_wheels/sympy-1.13.1-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
133
+ pip_wheels/timm-1.0.24-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
134
+ pip_wheels/tokenizers-0.21.4-cp39-abi3-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
135
+ pip_wheels/torch-2.6.0-cp312-cp312-manylinux1_x86_64.whl filter=lfs diff=lfs merge=lfs -text
136
+ pip_wheels/torchvision-0.21.0-cp312-cp312-manylinux1_x86_64.whl filter=lfs diff=lfs merge=lfs -text
137
+ pip_wheels/tornado-6.5.4-cp39-abi3-manylinux_2_5_x86_64.manylinux1_x86_64.manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
138
+ pip_wheels/transformers-4.49.0-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
139
+ pip_wheels/triton-3.2.0-cp312-cp312-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
140
+ pip_wheels/tzdata-2025.3-py2.py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
141
+ pip_wheels/urllib3-2.6.3-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
142
+ pip_wheels/wandb-0.23.0-py3-none-manylinux_2_17_x86_64.manylinux2014_x86_64.whl filter=lfs diff=lfs merge=lfs -text
143
+ pip_wheels/widgetsnbextension-4.0.15-py3-none-any.whl filter=lfs diff=lfs merge=lfs -text
FastGen/.claude/settings.local.json ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "permissions": {
3
+ "allow": [
4
+ "Bash(conda activate:*)",
5
+ "Bash(pip install:*)",
6
+ "Bash(source:*)",
7
+ "Bash(conda info:*)",
8
+ "Bash(pip cache:*)",
9
+ "Bash(python:*)",
10
+ "Bash(curl:*)",
11
+ "Bash(echo:*)",
12
+ "Bash(huggingface-cli download:*)",
13
+ "Bash(chmod:*)",
14
+ "Bash(bash:*)",
15
+ "Bash(ls:*)",
16
+ "Bash(huggingface-cli:*)"
17
+ ]
18
+ }
19
+ }
FastGen/.gitignore ADDED
@@ -0,0 +1,180 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # macOS
2
+ .DS_Store
3
+
4
+ # Byte-compiled / optimized / DLL files
5
+ __pycache__/
6
+ *.py[cod]
7
+ *$py.class
8
+
9
+ # C extensions
10
+ *.so
11
+
12
+ # Distribution / packaging
13
+ .Python
14
+ build/
15
+ develop-eggs/
16
+ dist/
17
+ downloads/
18
+ eggs/
19
+ .eggs/
20
+ lib/
21
+ lib64/
22
+ parts/
23
+ sdist/
24
+ var/
25
+ wheels/
26
+ share/python-wheels/
27
+ *.egg-info/
28
+ .installed.cfg
29
+ *.egg
30
+ MANIFEST
31
+
32
+ # PyInstaller
33
+ # Usually these files are written by a python script from a template
34
+ # before PyInstaller builds the exe, so as to inject date/other infos into it.
35
+ *.manifest
36
+ *.spec
37
+
38
+ # Installer logs
39
+ pip-log.txt
40
+ pip-delete-this-directory.txt
41
+
42
+ # Unit test / coverage reports
43
+ htmlcov/
44
+ .tox/
45
+ .nox/
46
+ .coverage
47
+ .coverage.*
48
+ .cache
49
+ nosetests.xml
50
+ coverage.xml
51
+ *.cover
52
+ *.py,cover
53
+ .hypothesis/
54
+ .pytest_cache/
55
+ cover/
56
+
57
+ # Translations
58
+ *.mo
59
+ *.pot
60
+
61
+ # Django stuff:
62
+ *.log
63
+ local_settings.py
64
+ db.sqlite3
65
+ db.sqlite3-journal
66
+
67
+ # Flask stuff:
68
+ instance/
69
+ .webassets-cache
70
+
71
+ # Scrapy stuff:
72
+ .scrapy
73
+
74
+ # Sphinx documentation
75
+ docs/_build/
76
+
77
+ # PyBuilder
78
+ .pybuilder/
79
+ target/
80
+
81
+ # Jupyter Notebook
82
+ .ipynb_checkpoints
83
+
84
+ # IPython
85
+ profile_default/
86
+ ipython_config.py
87
+
88
+ # pyenv
89
+ # For a library or package, you might want to ignore these files since the code is
90
+ # intended to run in multiple environments; otherwise, check them in:
91
+ # .python-version
92
+
93
+ # pipenv
94
+ # According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
95
+ # However, in case of collaboration, if having platform-specific dependencies or dependencies
96
+ # having no cross-platform support, pipenv may install dependencies that don't work, or not
97
+ # install all needed dependencies.
98
+ #Pipfile.lock
99
+
100
+ # poetry
101
+ # Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
102
+ # This is especially recommended for binary packages to ensure reproducibility, and is more
103
+ # commonly ignored for libraries.
104
+ # https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
105
+ #poetry.lock
106
+
107
+ # pdm
108
+ # Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
109
+ #pdm.lock
110
+ # pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
111
+ # in version control.
112
+ # https://pdm.fming.dev/#use-with-ide
113
+ .pdm.toml
114
+
115
+ # PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
116
+ __pypackages__/
117
+
118
+ # Celery stuff
119
+ celerybeat-schedule
120
+ celerybeat.pid
121
+
122
+ # SageMath parsed files
123
+ *.sage.py
124
+
125
+ # Environments
126
+ .env
127
+ .venv
128
+ env/
129
+ venv/
130
+ ENV/
131
+ env.bak/
132
+ venv.bak/
133
+
134
+ # Spyder project settings
135
+ .spyderproject
136
+ .spyproject
137
+
138
+ # Rope project settings
139
+ .ropeproject
140
+
141
+ # mkdocs documentation
142
+ /site
143
+
144
+ # mypy
145
+ .mypy_cache/
146
+ .dmypy.json
147
+ dmypy.json
148
+
149
+ # Pyre type checker
150
+ .pyre/
151
+
152
+ # pytype static type analyzer
153
+ .pytype/
154
+
155
+ # Cython debug symbols
156
+ cython_debug/
157
+
158
+ # PyCharm
159
+ # JetBrains specific template is maintained in a separate JetBrains.gitignore that can
160
+ # be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
161
+ # and can be added to the global gitignore or merged into this file. For a more nuclear
162
+ # option (not recommended) you can uncomment the following to ignore the entire idea folder.
163
+ .idea/
164
+
165
+ # large files
166
+ *.zip
167
+
168
+ # temporal folder
169
+ tmp/
170
+
171
+ # credentials
172
+ credentials
173
+
174
+ # logs
175
+ FASTGEN_OUTPUT/
176
+ wandb/
177
+ *.err
178
+
179
+ # vscode
180
+ *.vscode/
FastGen/CLAUDE.md ADDED
@@ -0,0 +1,7 @@
 
 
 
 
 
 
 
 
1
+ # FastGen Project Notes
2
+
3
+ ## Environment
4
+
5
+ - Conda environment: `sf`
6
+ - Activate: `conda activate sf`
7
+ - Install: `pip install -e .`
FastGen/CONTRIBUTING.md ADDED
@@ -0,0 +1,93 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Contributing to FastGen
2
+
3
+ Thank you for your interest in contributing to FastGen!
4
+
5
+ ## Issue Tracking
6
+
7
+ * All enhancement, bugfix, or change requests should begin with the creation of an issue.
8
+ * The issue should be reviewed and approved prior to code review.
9
+
10
+ ## Development Setup
11
+
12
+ All development should be done using the provided Docker image to ensure a consistent environment.
13
+
14
+ ```bash
15
+ # Build the Docker image
16
+ docker build -t fastgen:dev .
17
+
18
+ # Run interactively with GPU support (as current user, not root)
19
+ docker run --gpus all -it --rm \
20
+ --user $(id -u):$(id -g) \
21
+ -e HOME=/workspace \
22
+ -v /etc/passwd:/etc/passwd:ro \
23
+ -v /etc/group:/etc/group:ro \
24
+ -v $(pwd):/workspace \
25
+ fastgen:dev bash
26
+
27
+ # Inside the container, install linters
28
+ make install
29
+ ```
30
+
31
+ ## Code Style and Formatting
32
+
33
+ FastGen uses **Ruff** for linting and code formatting. Follow existing conventions when adding new code.
34
+
35
+ ```bash
36
+ make format # Auto-format code
37
+ make lint # Check compliance
38
+ ```
39
+
40
+ **Note:** The `fastgen/third_party/` directory is excluded from checks.
41
+
42
+ ## Continuous Integration
43
+
44
+ The GitLab CI pipeline runs automatically on every push and merge request:
45
+
46
+ | Stage | Commands |
47
+ |-------|----------|
48
+ | Lint | `make lint`, `make mypy` |
49
+ | Test | `make pytest` |
50
+ | Install | `make install-fastgen` |
51
+
52
+ **Note:** Before submitting a pull request, ensure all checks pass locally.
53
+ ## Pull Request Process
54
+
55
+ 1. Fork the repository (for external contributors) or create a feature branch.
56
+ 2. Make your changes following the code style guidelines.
57
+ 3. Run all checks locally (see [CI section](#continuous-integration)).
58
+ 4. Commit your changes with sign-off (PRs with unsigned commits will not be accepted):
59
+ ```bash
60
+ git commit -s -m "Add feature X"
61
+ ```
62
+ 5. Push and create a Pull Request.
63
+
64
+ **Note:** Keep PRs focused on a single concern and reference related issues (e.g., "Fixes #123").
65
+
66
+
67
+ ## Developer Certificate of Origin
68
+
69
+ ```
70
+ Developer Certificate of Origin
71
+ Version 1.1
72
+
73
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
74
+ 1 Letterman Drive
75
+ Suite D4700
76
+ San Francisco, CA, 94129
77
+
78
+ Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
79
+ ```
80
+
81
+ ```
82
+ Developer's Certificate of Origin 1.1
83
+
84
+ By making a contribution to this project, I certify that:
85
+
86
+ (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
87
+
88
+ (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
89
+
90
+ (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
91
+
92
+ (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
93
+ ```
FastGen/Dockerfile ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ FROM nvcr.io/nvidia/pytorch:25.06-py3
2
+
3
+ # Defines the architecture (amd64 or arm64) automatically during build
4
+ ARG TARGETARCH
5
+
6
+ ENV PYTHONDONTWRITEBYTECODE=1
7
+ ENV PYTHONUNBUFFERED=1
8
+ ENV DEBIAN_FRONTEND=noninteractive
9
+
10
+ # Set the timezone
11
+ ENV TZ=America/Los_Angeles
12
+ RUN ln -snf /usr/share/zoneinfo/$TZ /etc/localtime && echo $TZ > /etc/timezone
13
+
14
+ # Install system dependencies
15
+ RUN apt-get update --fix-missing && \
16
+ apt-get install -y --no-install-recommends \
17
+ build-essential \
18
+ ninja-build \
19
+ cmake \
20
+ wget \
21
+ ca-certificates \
22
+ git \
23
+ curl \
24
+ unzip \
25
+ python3-tk \
26
+ && rm -rf /var/lib/apt/lists/* \
27
+ && cmake --version \
28
+ && echo "✅ All system dependencies installed successfully"
29
+
30
+ # Install s5cmd with architecture logic
31
+ RUN S5CMD_VERSION="2.1.0-beta.1" && \
32
+ if [ "$TARGETARCH" = "arm64" ]; then \
33
+ S5CMD_ARCH="Linux-arm64"; \
34
+ else \
35
+ S5CMD_ARCH="Linux-64bit"; \
36
+ fi && \
37
+ wget "https://github.com/peak/s5cmd/releases/download/v${S5CMD_VERSION}/s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz" && \
38
+ tar -xf "s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz" && \
39
+ install s5cmd /usr/local/bin/ && \
40
+ rm s5cmd "s5cmd_${S5CMD_VERSION}_${S5CMD_ARCH}.tar.gz"
41
+
42
+ # Set build optimization flags
43
+ ENV MAX_JOBS=4
44
+ ENV TORCH_CUDA_ARCH_LIST="8.0;8.6;9.0;10.0"
45
+
46
+ # Upgrade pip
47
+ RUN pip install --upgrade pip setuptools wheel
48
+ # numpy<2.0.0 is required for the preinstalled numba and thinc
49
+ RUN pip install wandb[media] \
50
+ diffusers==0.35.1 \
51
+ transformers==4.49.0 \
52
+ accelerate \
53
+ safetensors \
54
+ scipy \
55
+ einops \
56
+ jupyter \
57
+ hydra-core \
58
+ imageio \
59
+ tqdm \
60
+ webdataset \
61
+ av \
62
+ loguru \
63
+ boto3 \
64
+ timm \
65
+ ftfy \
66
+ opencv-python-headless \
67
+ sentencepiece \
68
+ "numpy<2.0.0"
69
+ RUN pip uninstall -y apex
70
+ RUN pip install --index-url=https://sc-hw-artf.nvidia.com/artifactory/api/pypi/hwinf-mlwfo-pypi/simple --upgrade one-logger-utils
71
+
72
+ WORKDIR /workspace
73
+
74
+ # Entry point
75
+ RUN (printf '#!/bin/bash\nexec \"$@\"\n' >> /entry.sh) && chmod a+x /entry.sh
76
+ ENTRYPOINT ["/entry.sh"]
77
+
FastGen/LICENSE ADDED
@@ -0,0 +1,202 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+
2
+ Apache License
3
+ Version 2.0, January 2004
4
+ http://www.apache.org/licenses/
5
+
6
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
7
+
8
+ 1. Definitions.
9
+
10
+ "License" shall mean the terms and conditions for use, reproduction,
11
+ and distribution as defined by Sections 1 through 9 of this document.
12
+
13
+ "Licensor" shall mean the copyright owner or entity authorized by
14
+ the copyright owner that is granting the License.
15
+
16
+ "Legal Entity" shall mean the union of the acting entity and all
17
+ other entities that control, are controlled by, or are under common
18
+ control with that entity. For the purposes of this definition,
19
+ "control" means (i) the power, direct or indirect, to cause the
20
+ direction or management of such entity, whether by contract or
21
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
22
+ outstanding shares, or (iii) beneficial ownership of such entity.
23
+
24
+ "You" (or "Your") shall mean an individual or Legal Entity
25
+ exercising permissions granted by this License.
26
+
27
+ "Source" form shall mean the preferred form for making modifications,
28
+ including but not limited to software source code, documentation
29
+ source, and configuration files.
30
+
31
+ "Object" form shall mean any form resulting from mechanical
32
+ transformation or translation of a Source form, including but
33
+ not limited to compiled object code, generated documentation,
34
+ and conversions to other media types.
35
+
36
+ "Work" shall mean the work of authorship, whether in Source or
37
+ Object form, made available under the License, as indicated by a
38
+ copyright notice that is included in or attached to the work
39
+ (an example is provided in the Appendix below).
40
+
41
+ "Derivative Works" shall mean any work, whether in Source or Object
42
+ form, that is based on (or derived from) the Work and for which the
43
+ editorial revisions, annotations, elaborations, or other modifications
44
+ represent, as a whole, an original work of authorship. For the purposes
45
+ of this License, Derivative Works shall not include works that remain
46
+ separable from, or merely link (or bind by name) to the interfaces of,
47
+ the Work and Derivative Works thereof.
48
+
49
+ "Contribution" shall mean any work of authorship, including
50
+ the original version of the Work and any modifications or additions
51
+ to that Work or Derivative Works thereof, that is intentionally
52
+ submitted to Licensor for inclusion in the Work by the copyright owner
53
+ or by an individual or Legal Entity authorized to submit on behalf of
54
+ the copyright owner. For the purposes of this definition, "submitted"
55
+ means any form of electronic, verbal, or written communication sent
56
+ to the Licensor or its representatives, including but not limited to
57
+ communication on electronic mailing lists, source code control systems,
58
+ and issue tracking systems that are managed by, or on behalf of, the
59
+ Licensor for the purpose of discussing and improving the Work, but
60
+ excluding communication that is conspicuously marked or otherwise
61
+ designated in writing by the copyright owner as "Not a Contribution."
62
+
63
+ "Contributor" shall mean Licensor and any individual or Legal Entity
64
+ on behalf of whom a Contribution has been received by Licensor and
65
+ subsequently incorporated within the Work.
66
+
67
+ 2. Grant of Copyright License. Subject to the terms and conditions of
68
+ this License, each Contributor hereby grants to You a perpetual,
69
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
70
+ copyright license to reproduce, prepare Derivative Works of,
71
+ publicly display, publicly perform, sublicense, and distribute the
72
+ Work and such Derivative Works in Source or Object form.
73
+
74
+ 3. Grant of Patent License. Subject to the terms and conditions of
75
+ this License, each Contributor hereby grants to You a perpetual,
76
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
77
+ (except as stated in this section) patent license to make, have made,
78
+ use, offer to sell, sell, import, and otherwise transfer the Work,
79
+ where such license applies only to those patent claims licensable
80
+ by such Contributor that are necessarily infringed by their
81
+ Contribution(s) alone or by combination of their Contribution(s)
82
+ with the Work to which such Contribution(s) was submitted. If You
83
+ institute patent litigation against any entity (including a
84
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
85
+ or a Contribution incorporated within the Work constitutes direct
86
+ or contributory patent infringement, then any patent licenses
87
+ granted to You under this License for that Work shall terminate
88
+ as of the date such litigation is filed.
89
+
90
+ 4. Redistribution. You may reproduce and distribute copies of the
91
+ Work or Derivative Works thereof in any medium, with or without
92
+ modifications, and in Source or Object form, provided that You
93
+ meet the following conditions:
94
+
95
+ (a) You must give any other recipients of the Work or
96
+ Derivative Works a copy of this License; and
97
+
98
+ (b) You must cause any modified files to carry prominent notices
99
+ stating that You changed the files; and
100
+
101
+ (c) You must retain, in the Source form of any Derivative Works
102
+ that You distribute, all copyright, patent, trademark, and
103
+ attribution notices from the Source form of the Work,
104
+ excluding those notices that do not pertain to any part of
105
+ the Derivative Works; and
106
+
107
+ (d) If the Work includes a "NOTICE" text file as part of its
108
+ distribution, then any Derivative Works that You distribute must
109
+ include a readable copy of the attribution notices contained
110
+ within such NOTICE file, excluding those notices that do not
111
+ pertain to any part of the Derivative Works, in at least one
112
+ of the following places: within a NOTICE text file distributed
113
+ as part of the Derivative Works; within the Source form or
114
+ documentation, if provided along with the Derivative Works; or,
115
+ within a display generated by the Derivative Works, if and
116
+ wherever such third-party notices normally appear. The contents
117
+ of the NOTICE file are for informational purposes only and
118
+ do not modify the License. You may add Your own attribution
119
+ notices within Derivative Works that You distribute, alongside
120
+ or as an addendum to the NOTICE text from the Work, provided
121
+ that such additional attribution notices cannot be construed
122
+ as modifying the License.
123
+
124
+ You may add Your own copyright statement to Your modifications and
125
+ may provide additional or different license terms and conditions
126
+ for use, reproduction, or distribution of Your modifications, or
127
+ for any such Derivative Works as a whole, provided Your use,
128
+ reproduction, and distribution of the Work otherwise complies with
129
+ the conditions stated in this License.
130
+
131
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
132
+ any Contribution intentionally submitted for inclusion in the Work
133
+ by You to the Licensor shall be under the terms and conditions of
134
+ this License, without any additional terms or conditions.
135
+ Notwithstanding the above, nothing herein shall supersede or modify
136
+ the terms of any separate license agreement you may have executed
137
+ with Licensor regarding such Contributions.
138
+
139
+ 6. Trademarks. This License does not grant permission to use the trade
140
+ names, trademarks, service marks, or product names of the Licensor,
141
+ except as required for reasonable and customary use in describing the
142
+ origin of the Work and reproducing the content of the NOTICE file.
143
+
144
+ 7. Disclaimer of Warranty. Unless required by applicable law or
145
+ agreed to in writing, Licensor provides the Work (and each
146
+ Contributor provides its Contributions) on an "AS IS" BASIS,
147
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
148
+ implied, including, without limitation, any warranties or conditions
149
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
150
+ PARTICULAR PURPOSE. You are solely responsible for determining the
151
+ appropriateness of using or redistributing the Work and assume any
152
+ risks associated with Your exercise of permissions under this License.
153
+
154
+ 8. Limitation of Liability. In no event and under no legal theory,
155
+ whether in tort (including negligence), contract, or otherwise,
156
+ unless required by applicable law (such as deliberate and grossly
157
+ negligent acts) or agreed to in writing, shall any Contributor be
158
+ liable to You for damages, including any direct, indirect, special,
159
+ incidental, or consequential damages of any character arising as a
160
+ result of this License or out of the use or inability to use the
161
+ Work (including but not limited to damages for loss of goodwill,
162
+ work stoppage, computer failure or malfunction, or any and all
163
+ other commercial damages or losses), even if such Contributor
164
+ has been advised of the possibility of such damages.
165
+
166
+ 9. Accepting Warranty or Additional Liability. While redistributing
167
+ the Work or Derivative Works thereof, You may choose to offer,
168
+ and charge a fee for, acceptance of support, warranty, indemnity,
169
+ or other liability obligations and/or rights consistent with this
170
+ License. However, in accepting such obligations, You may act only
171
+ on Your own behalf and on Your sole responsibility, not on behalf
172
+ of any other Contributor, and only if You agree to indemnify,
173
+ defend, and hold each Contributor harmless for any liability
174
+ incurred by, or claims asserted against, such Contributor by reason
175
+ of your accepting any such warranty or additional liability.
176
+
177
+ END OF TERMS AND CONDITIONS
178
+
179
+ APPENDIX: How to apply the Apache License to your work.
180
+
181
+ To apply the Apache License to your work, attach the following
182
+ boilerplate notice, with the fields enclosed by brackets "[]"
183
+ replaced with your own identifying information. (Don't include
184
+ the brackets!) The text should be enclosed in the appropriate
185
+ comment syntax for the file format. We also recommend that a
186
+ file or class name and description of purpose be included on the
187
+ same "printed page" as the copyright notice for easier
188
+ identification within third-party archives.
189
+
190
+ Copyright 2026 NVIDIA CORPORATION & AFFILIATES
191
+
192
+ Licensed under the Apache License, Version 2.0 (the "License");
193
+ you may not use this file except in compliance with the License.
194
+ You may obtain a copy of the License at
195
+
196
+ http://www.apache.org/licenses/LICENSE-2.0
197
+
198
+ Unless required by applicable law or agreed to in writing, software
199
+ distributed under the License is distributed on an "AS IS" BASIS,
200
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
201
+ See the License for the specific language governing permissions and
202
+ limitations under the License.
FastGen/Makefile ADDED
@@ -0,0 +1,29 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# Optional local overrides (API keys, paths); `sinclude` ignores a missing file.
sinclude .env

# Default goal: print the help text (targets annotated with `##`).
all: help

install-linters: ## install the linters
	pip install black==23.10.0 ruff==0.6.9 mypy==1.9.0 types-psutil

install: install-linters

mypy:
	python3 -m mypy --check-untyped-defs --follow-imports=silent --exclude third_party/ train.py

lint:
	python3 -m ruff format --check --exclude fastgen/third_party/
	python3 -m ruff check --exclude fastgen/third_party/

format:
	python3 -m ruff format --exclude fastgen/third_party/
	python3 -m ruff check --fix --exclude fastgen/third_party/

install-fastgen:
	python3 -m pip install -e .

pytest:
	ulimit -n 4096 && python3 -m pytest --ignore=FASTGEN_OUTPUT --ignore=runs --ignore=tmp --ignore third_party

# Every target here is a command shortcut, not a produced file: declare them all
# phony so a stray file named e.g. `lint` or `format` cannot shadow the target.
.PHONY: all install-linters install mypy lint format install-fastgen pytest help
help:
	@grep -E '^[a-zA-Z_-]+:.*?## .*$$' $(MAKEFILE_LIST) | sort | awk 'BEGIN {FS = ":.*?## "}; {printf "\033[36m%-30s\033[0m %s\n", $$1, $$2}'
FastGen/README.md ADDED
@@ -0,0 +1,182 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <h1 align="center">NVIDIA FastGen: Fast Generation from Diffusion Models</h1>
2
+
3
+ [![License](https://img.shields.io/badge/License-Apache%202.0-blue.svg)](LICENSE)
4
+ [![PyTorch](https://img.shields.io/badge/PyTorch-2.0%2B-ee4c2c?logo=pytorch&logoColor=white)](https://pytorch.org/)
5
+ [![NVIDIA](https://img.shields.io/badge/NVIDIA-76B900?logo=nvidia&logoColor=white)](https://www.nvidia.com/)
6
+
7
+ <p align="center">
8
+ <b></b> <a href="https://weilinie.github.io/">Weili Nie</a> • <a href="https://jberner.info/">Julius Berner</a> • <a href="https://research.nvidia.com/labs/genair/author/chao-liu/">Chao Liu</a> • <a href="http://latentspace.cc/">Arash Vahdat</a>
9
+ </p>
10
+
11
+ <p align="center">
12
+ <img src="assets/teaser.png" alt="FastGen header" width=100%/>
13
+ </p>
14
+
15
+ FastGen is a PyTorch-based framework for building fast generative models using various distillation and acceleration techniques. It supports:
16
+ - large-scale training with ≥10B parameters.
17
+ - different tasks and modalities, including T2I, I2V, and V2V.
18
+ - various distillation methods, including consistency models, distribution matching distillation, self-forcing, and more.
19
+
20
+ ## Repository Structure
21
+
22
+ ```
23
+ fastgen/
24
+ ├── fastgen/
25
+ │ ├── callbacks/ # Training callbacks (EMA, profiling, etc.)
26
+ │ ├── configs/ # Configuration system
27
+ │ │ ├── experiments/ # Experiment configs
28
+ │ │ └── methods/ # Method-specific configs
29
+ │ ├── datasets/ # Dataset loaders
30
+ │ ├── methods/ # Training methods (CM, DMD2, SFT, KD etc.)
31
+ │ ├── networks/ # Neural network architectures
32
+ │ ├── third_party/ # Third-party dependencies
33
+ │ ├── trainer.py # Main training loop
34
+ │ └── utils/ # Utilities (distributed, checkpointing)
35
+ ├── scripts/ # Inference and evaluation scripts
36
+ ├── tests/ # Unit tests
37
+ ├── Makefile # Development commands (lint, format, test)
38
+ └── train.py # Main training entry point
39
+ ```
40
+
41
+ ## Setup
42
+
43
+ **Recommended:** Use the provided Docker container for a consistent environment. See [CONTRIBUTING.md](CONTRIBUTING.md) for Docker setup instructions. Otherwise, create a new [conda](https://www.anaconda.com/docs/getting-started/miniconda/install) environment with `conda create -y -n fastgen python=3.12.3 pip; conda activate fastgen`.
44
+
45
+ ### Installation
46
+
47
+ ```bash
48
+ git clone https://github.com/NVlabs/FastGen.git
49
+ cd FastGen
50
+ pip install -e .
51
+ ```
52
+
53
+ ### Credentials (Optional)
54
+
55
+ For W&B logging, [get your API key](https://wandb.ai/settings) and save it to `credentials/wandb_api.txt` or set the `WANDB_API_KEY` environment variable.
56
+ Without either of these, W&B will prompt for your API key interactively.
57
+ For more details, including S3 storage and other environment variables, see [fastgen/configs/README.md](fastgen/configs/README.md#environment-variables).
58
+
59
+ ## Quick Start
60
+
61
+ Before running the following commands, download the CIFAR-10 dataset and pretrained EDM models:
62
+
63
+ ```bash
64
+ python scripts/download_data.py --dataset cifar10
65
+ ```
66
+
67
+ For other datasets and models, see [fastgen/networks/README.md](fastgen/networks/README.md) and [fastgen/datasets/README.md](fastgen/datasets/README.md).
68
+
69
+ ### Basic Training
70
+
71
+ ```bash
72
+ python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py
73
+ ```
74
+
75
+ If you run out of memory, try a smaller batch-size, e.g., `dataloader_train.batch_size=32`, which automatically uses gradient accumulation to match the global batch-size.
76
+
77
+ **Expected Output:** See the training log for a link to the run on [wandb.ai](https://wandb.ai). Training outputs go to `$FASTGEN_OUTPUT_ROOT/{project}/{group}/{name}/`. With default settings, outputs are organized as follows:
78
+ ```
79
+ FASTGEN_OUTPUT/fastgen/cifar10/debug/
80
+ ├── checkpoints/ # Model checkpoints in the format {iteration:07d}.pth
81
+ │ ├── 0001000.pth
82
+ │ └── ...
83
+ ├── config.yaml # Resolved configuration for reproducibility
84
+ ├── wandb_id.txt # W&B run ID for resuming
85
+ └── ...
86
+ ```
87
+
88
+ ### DDP/FSDP2 Training
89
+
90
+ For multi-GPU training, use DDP:
91
+
92
+ ```bash
93
+ torchrun --nproc_per_node=8 train.py \
94
+ --config=fastgen/configs/experiments/EDM/config_dmd2_test.py \
95
+ - trainer.ddp=True log_config.name=test_ddp
96
+ ```
97
+
98
+ For large models, use FSDP2 for model sharding by replacing `trainer.ddp=True` with `trainer.fsdp=True`.
99
+
100
+
101
+ ### Inference
102
+
103
+ ```bash
104
+ python scripts/inference/image_model_inference.py --config fastgen/configs/experiments/EDM/config_dmd2_test.py \
105
+ --classes=10 --prompt_file=scripts/inference/prompts/classes.txt --ckpt=FASTGEN_OUTPUT/fastgen/cifar10/debug/checkpoints/0002000.pth - log_config.name=test_inference
106
+ ```
107
+
108
+ For other inference modes and FID evaluations, see [scripts/README.md](scripts/README.md).
109
+
110
+
111
+ ### Command-Line Overrides
112
+
113
+ Override any config parameter using Hydra-style syntax (note the `-` separator):
114
+
115
+ ```bash
116
+ python train.py --config=path/to/config.py - key=value nested.key=value
117
+ ```
118
+
119
+ ## Documentation
120
+
121
+ Detailed documentation is available in each component's README:
122
+
123
+ | Component | Documentation | Description |
124
+ |-----------|---------------|-------------|
125
+ | **Methods** | [fastgen/methods/README.md](fastgen/methods/README.md) | Training methods (sCM, MeanFlow, DMD2, Self-Forcing, etc.) |
126
+ | **Networks** | [fastgen/networks/README.md](fastgen/networks/README.md) | Network architectures (EDM, SD, SDXL, Flux, WAN, CogVideoX, Cosmos) and pretrained models |
127
+ | **Configs** | [fastgen/configs/README.md](fastgen/configs/README.md) | Configuration system, environment variables, and creating custom configs |
128
+ | **Datasets** | [fastgen/datasets/README.md](fastgen/datasets/README.md) | Dataset preparation and WebDataset loaders |
129
+ | **Callbacks** | [fastgen/callbacks/README.md](fastgen/callbacks/README.md) | Training callbacks (EMA, logging, gradient clipping, etc.) |
130
+ | **Inference** | [scripts/README.md](scripts/README.md) | Inference modes (T2I, T2V, I2V, V2V, etc.) and FID evaluation |
131
+ | **Third Party** | [fastgen/third_party/README.md](fastgen/third_party/README.md) | Third-party dependencies (Depth Anything V2, etc.) |
132
+
133
+ ## Supported Methods
134
+
135
+ | Category | Methods |
136
+ |----------|---------|
137
+ | **Consistency Models** | [CM](https://arxiv.org/abs/2303.01469), [sCM](https://arxiv.org/abs/2410.11081), [TCM](https://arxiv.org/abs/2410.14895), [MeanFlow](https://arxiv.org/abs/2505.13447) |
138
+ | **Distribution Matching** | [DMD2](https://arxiv.org/abs/2405.14867), [f-Distill](https://arxiv.org/abs/2502.15681), [LADD](https://arxiv.org/abs/2403.12015), [CausVid](https://arxiv.org/abs/2412.07772), [Self-Forcing](https://arxiv.org/abs/2506.08009) |
139
+ | **Fine-Tuning** | [SFT](https://arxiv.org/abs/2006.11239), [CausalSFT](https://arxiv.org/abs/2407.01392) |
140
+ | **Knowledge Distillation** | [KD](https://arxiv.org/abs/2101.02388), [CausalKD](https://arxiv.org/abs/2412.07772) |
141
+
142
+ See [fastgen/methods/README.md](fastgen/methods/README.md) for details.
143
+
144
+ ## Supported Networks and Data
145
+
146
+ FastGen is designed to be **agnostic to the network and data** and you can add your own architectures and datasets (see [fastgen/networks/README.md](fastgen/networks/README.md) and [fastgen/datasets/README.md](fastgen/datasets/README.md)). For reference, we provide the following implementations:
147
+
148
+ | Data | Networks |
149
+ |------|----------|
150
+ | **Image** | EDM, EDM2, DiT, SD 1.5, SDXL, Flux |
151
+ | **Video** | WAN (T2V, I2V, VACE), CogVideoX, Cosmos Predict2 |
152
+
153
+ See [fastgen/networks/README.md](fastgen/networks/README.md) for details.
154
+ Not all combinations of methods and networks are currently supported. We provide typical use-cases in our predefined configs in [fastgen/configs/experiments](fastgen/configs/experiments).
155
+
156
+ **We plan to provide distilled student checkpoints for CIFAR-10 and ImageNet soon.**
157
+
158
+ ## Contributing
159
+
160
+ We welcome contributions! Please see [CONTRIBUTING.md](CONTRIBUTING.md) for details.
161
+
162
+ We thank everyone who has helped design, build, and test FastGen!
163
+
164
+ - **Core contributors:** Weili Nie, Julius Berner, Chao Liu
165
+ - **Other contributors:** James Lucas, David Pankratz, Sihyun Yu, Willis Ma, Yilun Xu, Shengqu Cai, Xinyin Ma, Yanke Song
166
+ - **Collaborators:** Sophia Zalewski, Wei Xiong, Christian Laforte, Sajad Norouzi, Kaiwen Zheng, Miloš Hašan, Saeed Hadadan, Gene Liu, David Dynerman, Grace Lam, Pooya Jannaty, Jan Kautz, and many more.
167
+ - **Project lead:** Arash Vahdat
168
+
169
+ ## License
170
+
171
+ This project is licensed under the Apache License 2.0 - see [LICENSE](LICENSE) for details. Third-party licenses are documented in [licenses/README.md](licenses/README.md).
172
+
173
+ ## Reference
174
+
175
+ ```
176
+ @article{fastgen2026,
177
+ title={NVIDIA FastGen: Fast Generation from Diffusion Models},
178
+ author={Nie, Weili and Berner, Julius and Liu, Chao and Vahdat, Arash},
179
+ url={https://github.com/NVlabs/FastGen},
180
+ year={2026},
181
+ }
182
+ ```
FastGen/SF.md ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Self-Forcing: Quick Reference Guide
2
+
3
+ Self-Forcing is a distribution matching distillation method for causal/autoregressive video generation. It trains a fast student model to match a teacher model's output distribution while maintaining gradient flow through an autoregressive rollout process.
4
+
5
+ **Reference**: [Huang et al., 2025](https://arxiv.org/abs/2506.08009)
6
+
7
+ ---
8
+
9
+ ## Environment Setup
10
+
11
+ ```bash
12
+ conda activate sf
13
+ pip install -e .
14
+ ```
15
+
16
+ ---
17
+
18
+ ## Quick Start
19
+
20
+ ### 1. Download Checkpoint
21
+
22
+ ```bash
23
+ huggingface-cli download gdhe17/Self-Forcing checkpoints/ode_init.pt \
24
+ --local-dir FASTGEN_OUTPUT/MODEL/Self-Forcing
25
+ ```
26
+
27
+ ### 2. Run Training
28
+
29
+ #### Data-Free Training (Recommended for testing)
30
+
31
+ Uses random tensors instead of real video data. GAN discriminator is disabled.
32
+
33
+ **Single GPU:**
34
+ ```bash
35
+ python train.py --config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py
36
+ ```
37
+
38
+ **Single GPU (specific device):**
39
+ ```bash
40
+ CUDA_VISIBLE_DEVICES=0 python train.py --config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py
41
+ ```
42
+
43
+ **Multi-GPU (8 GPUs):**
44
+ ```bash
45
+ torchrun --nproc_per_node=8 train.py \
46
+ --config=fastgen/configs/experiments/WanT2V/config_sf_datafree.py \
47
+ - trainer.ddp=True
48
+ ```
49
+
50
+ #### Training with Real Video Data
51
+
52
+ For best quality, use real video data with GAN discriminator enabled.
53
+
54
+ Prepare WebDataset shards (see `fastgen/datasets/README.md`):
55
+ ```
56
+ /path/to/videos/
57
+ ├── 00000.tar (sample_001.mp4, sample_001.txt, ...)
58
+ ├── 00001.tar
59
+ └── ...
60
+ ```
61
+
62
+ **Single GPU:**
63
+ ```bash
64
+ python train.py --config=fastgen/configs/experiments/WanT2V/config_sf.py \
65
+ - dataloader_train.datatags='["WDS:/path/to/your/videos"]'
66
+ ```
67
+
68
+ **Multi-GPU (8 GPUs):**
69
+ ```bash
70
+ torchrun --nproc_per_node=8 train.py \
71
+ --config=fastgen/configs/experiments/WanT2V/config_sf.py \
72
+ - trainer.ddp=True dataloader_train.datatags='["WDS:/path/to/your/videos"]'
73
+ ```
74
+
75
+ ---
76
+
77
+ ## Available Configs
78
+
79
+ | Model | Config Path | Notes |
80
+ |-------|-------------|-------|
81
+ | WAN T2V (light) | `fastgen/configs/experiments/WanT2V/config_sf_datafree_light.py` | Data-free, reduced resolution (~8-12GB VRAM) |
82
+ | WAN T2V (data-free) | `fastgen/configs/experiments/WanT2V/config_sf_datafree.py` | Data-free, full 480p (~24GB VRAM) |
83
+ | WAN T2V (with data) | `fastgen/configs/experiments/WanT2V/config_sf.py` | Requires video data, GAN enabled |
84
+ | VACE-WAN V2V | `fastgen/configs/experiments/WanV2V/config_sf.py` | Requires video data |
85
+
86
+ ---
87
+
88
+ ## Key Configuration Parameters
89
+
90
+ ### Self-Forcing Specific
91
+
92
+ | Parameter | Default | Description |
93
+ |-----------|---------|-------------|
94
+ | `enable_gradient_in_rollout` | `True` | Enable gradients at exit step during rollout |
95
+ | `start_gradient_frame` | `0` | Frame index to start gradient tracking |
96
+ | `same_step_across_blocks` | `True` | Use same exit step for all blocks |
97
+ | `last_step_only` | `False` | Exit only at the last denoising step |
98
+ | `context_noise` | `0.0` | Noise level added to cached context (0-1) |
99
+
100
+ ### Training Settings
101
+
102
+ | Parameter | Default | Description |
103
+ |-----------|---------|-------------|
104
+ | `student_sample_steps` | `4` | Number of denoising steps |
105
+ | `student_update_freq` | `5` | Student update frequency |
106
+ | `gan_loss_weight_gen` | `0.001` | GAN loss weight for generator |
107
+
108
+ ---
109
+
110
+ ## Custom Training Example
111
+
112
+ ```bash
113
+ python train.py --config=fastgen/configs/experiments/WanT2V/config_sf.py \
114
+ - model.gan_loss_weight_gen=0.005 \
115
+ model.student_sample_steps=6 \
116
+ model.context_noise=0.1 \
117
+ trainer.max_iter=10000 \
118
+ dataloader_train.batch_size=2
119
+ ```
120
+
121
+ ---
122
+
123
+ ## How It Works
124
+
125
+ 1. **Autoregressive Rollout**: Processes video chunk-by-chunk with KV-caching
126
+ 2. **Gradient Tracking**: Enables gradients at stochastic exit steps during rollout
127
+ 3. **Distribution Matching**: Combines VSD loss with GAN training
128
+ 4. **Alternating Updates**:
129
+ - Student updates (VSD + GAN loss)
130
+ - Discriminator/fake score updates (adversarial training)
131
+
132
+ ---
133
+
134
+ ## Key Files
135
+
136
+ - **Method**: `fastgen/methods/distribution_matching/self_forcing.py`
137
+ - **Config Template**: `fastgen/configs/methods/config_self_forcing.py`
138
+ - **Tests**: `tests/test_sfmodel.py`
139
+
140
+ ---
141
+
142
+ ## Testing
143
+
144
+ ```bash
145
+ pytest tests/test_sfmodel.py -v
146
+ ```
FastGen/assets/teaser.png ADDED

Git LFS Details

  • SHA256: a824dd86d32d840adcf4d4d182930b4544cc11c7ae5237fb27d5ac05e2f42ddf
  • Pointer size: 132 Bytes
  • Size of remote file: 1.46 MB
FastGen/fastgen/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
FastGen/fastgen/callbacks/README.md ADDED
@@ -0,0 +1,59 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Callbacks
2
+
3
+ Training callbacks for FastGen. All callbacks inherit from `Callback` and hook into the training lifecycle.
4
+
5
+ ## Available Callbacks
6
+
7
+ | Callback | Description |
8
+ |----------|-------------|
9
+ | [`EMACallback`](ema.py) | Exponential Moving Average updates (constant, power, or halflife schedules) |
10
+ | [`GradClipCallback`](grad_clip.py) | Gradient norm clipping with NaN/Inf handling |
11
+ | [`WandbCallback`](wandb.py) | Weights & Biases logging (metrics, images, videos) |
12
+ | [`GPUStatsCallback`](gpu_stats.py) | GPU memory and utilization logging |
13
+ | [`MemTrackerCallback`](gpu_mem_profiler.py) | GPU memory profiling with HTML visualizations |
14
+ | [`TrainProfilerCallback`](train_profiler.py) | Training speed and timing breakdown |
15
+ | [`ParamCountCallback`](param_count.py) | Parameter count logging |
16
+ | [`ForcedWeightNormCallback`](forced_weight_norm.py) | Forced weight normalization (EDM2) |
17
+ | [`CTScheduleCallback`](ct_schedule.py) | Consistency training curriculum schedule |
18
+
19
+ ## Usage
20
+
21
+ Configure callbacks as a dictionary in the config:
22
+
23
+ ```python
24
+ from fastgen.configs.callbacks import WANDB_CALLBACK, GradClip_CALLBACK, EMA_CALLBACK
25
+ from omegaconf import DictConfig
26
+
27
+ config.trainer.callbacks = DictConfig({
28
+ **WANDB_CALLBACK,
29
+ **GradClip_CALLBACK,
30
+ **EMA_CALLBACK,
31
+ })
32
+ ```
33
+
34
+ Predefined configs are in `fastgen/configs/callbacks.py`.
35
+
36
+ ## Lifecycle Hooks
37
+
38
+ Callbacks can override these hooks:
39
+
40
+ | Phase | Hooks |
41
+ |-------|-------|
42
+ | Setup | `on_app_begin`, `on_model_init_start/end`, `on_optimizer_init_start/end`, `on_load_checkpoint_start/end`, `on_dataloader_init_start/end` |
43
+ | Training | `on_train_begin/end`, `on_training_step_begin/end`, `on_training_accum_step_begin`, `on_backward_begin`, `on_optimizer_step_begin` |
44
+ | Validation | `on_validation_begin/end`, `on_validation_step_begin/end` |
45
+ | Checkpointing | `on_save_checkpoint_start/success/end`, `state_dict`, `load_state_dict` |
46
+ | Cleanup | `on_app_end` |
47
+
48
+ ## Custom Callbacks
49
+
50
+ ```python
51
+ from fastgen.callbacks.callback import Callback
52
+
53
+ class MyCallback(Callback):
54
+ def on_training_step_end(self, model, data_batch, output_batch, loss_dict, iteration=0):
55
+ if iteration % self.config.trainer.logging_iter == 0:
56
+ print(f"Step {iteration}: loss = {loss_dict['total_loss'].item():.4f}")
57
+ ```
58
+
59
+ Callbacks have access to `self.config` and `self.trainer` after initialization.
FastGen/fastgen/callbacks/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
FastGen/fastgen/callbacks/callback.py ADDED
@@ -0,0 +1,183 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+ from typing import Callable, Any, TYPE_CHECKING
6
+ import torch
7
+ from torch.utils.data import DataLoader
8
+
9
+ from fastgen.utils import instantiate
10
+ import fastgen.utils.logging_utils as logger
11
+
12
+ if TYPE_CHECKING:
13
+ from fastgen.configs.config import BaseConfig
14
+ from fastgen.trainer import Trainer
15
+ from fastgen.methods import FastGenModel
16
+
17
+
18
class CallbackDict:
    """Container that instantiates configured callbacks and fans calls out to all of them.

    Any attribute access that is not ``state_dict`` / ``load_state_dict`` returns a
    function that invokes the same-named method on every registered callback.
    """

    def __init__(self, config: BaseConfig, trainer: Trainer):
        self._callbacks = {}
        callback_configs = config.trainer.callbacks
        if callback_configs:
            # Legacy list-style configs are converted to a dict with synthetic names.
            if isinstance(callback_configs, list):
                logger.warning(msg="The 'config.trainer.callbacks' parameter should be a dict instead of a list. ")
                callback_configs = {f"callback_{k}": v for k, v in enumerate(callback_configs)}
            for name, cfg in callback_configs.items():
                if "_target_" not in cfg:
                    logger.critical(f"Callback {name} is missing the '_target_' field. \n Skip {cfg}")
                    continue
                logger.critical(f"Instantiating callback {name}: {cfg}")
                cb = instantiate(cfg)
                assert isinstance(cb, Callback), f"{cfg} is not a valid callback."
                # Give every callback access to the full config/trainer before use.
                cb.config = config
                cb.trainer = trainer
                cb.on_app_begin()
                self._callbacks[name] = cb

    def __getattr__(self, method_name: str) -> Callable:
        # Checkpoint (de)serialization gets dedicated handlers keyed by callback name.
        if method_name == "load_state_dict":

            def load_state_dict(state_dict: dict[str, Any]) -> None:
                for name, callback in self._callbacks.items():
                    if name not in state_dict:
                        logger.warning(f"Callback {name} not found in checkpoint.")
                    else:
                        callback.load_state_dict(state_dict[name])

            return load_state_dict

        if method_name == "state_dict":

            def state_dict() -> dict[str, Any]:
                return {name: cb.state_dict() for name, cb in self._callbacks.items()}

            return state_dict

        # Everything else broadcasts to the same-named method on each callback.
        def fan_out(*args, **kwargs):
            for cb in self._callbacks.values():
                assert hasattr(cb, method_name)
                bound = getattr(cb, method_name)
                assert callable(bound), f"{method_name} is not callable."
                bound(*args, **kwargs)

        return fan_out
63
+
64
+
65
class Callback:
    """Base class for training callbacks.

    Every hook is a no-op; subclasses override only the hooks they need.
    ``CallbackDict`` assigns ``config`` and ``trainer`` to each instance and
    calls ``on_app_begin`` right after instantiation, so hooks may rely on
    ``self.config`` / ``self.trainer`` being set.
    """

    # Injected by CallbackDict immediately after instantiation.
    config: "BaseConfig"
    trainer: "Trainer"

    def on_app_begin(self) -> None:
        """Called once, right after the callback is instantiated and wired up."""
        pass

    def on_model_init_start(self, model: FastGenModel) -> None:
        """Hook before model initialization."""
        pass

    def on_model_init_end(self, model: FastGenModel | torch.nn.parallel.DistributedDataParallel) -> None:
        """Hook after model initialization (model may be DDP-wrapped)."""
        pass

    def on_optimizer_init_start(self, model: FastGenModel) -> None:
        """Hook before optimizer initialization."""
        pass

    def on_optimizer_init_end(self, model: FastGenModel) -> None:
        """Hook after optimizer initialization."""
        pass

    def on_load_checkpoint_start(self, model: FastGenModel) -> None:
        """Hook before checkpoint loading."""
        pass

    def on_load_checkpoint_end(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook after checkpoint loading; ``iteration`` is the restored step."""
        pass

    def on_dataloader_init_start(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook before dataloader construction."""
        pass

    def on_dataloader_init_end(
        self, model: FastGenModel, dataloader_train: DataLoader, dataloader_val: DataLoader, iteration: int = 0
    ) -> None:
        """Hook after train/val dataloaders are constructed."""
        pass

    def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook at the start of the training loop."""
        pass

    def on_training_step_begin(
        self,
        model: FastGenModel,
        iteration: int = 0,
    ) -> None:
        """Hook at the start of each training step."""
        pass

    def on_training_accum_step_begin(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        iteration: int = 0,
        accum_iter: int = 0,
    ) -> None:
        """Hook at the start of each gradient-accumulation micro-step."""
        pass

    def on_backward_begin(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
        accum_iter: int = 0,
    ) -> None:
        """Hook after the forward pass, before backward."""
        pass

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Hook at the end of each training step (e.g. for metric logging)."""
        pass

    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook before the optimizer step (e.g. for gradient clipping)."""
        pass

    def on_train_end(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook after the training loop finishes."""
        pass

    def on_validation_begin(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
        """Hook at the start of a validation run; ``idx`` identifies the run."""
        pass

    def on_validation_step_begin(
        self, model: FastGenModel, data_batch: dict[str, torch.Tensor], step: int = 0, iteration: int = 0, idx: int = 0
    ) -> None:
        """Hook before each validation step."""
        pass

    def on_validation_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        step: int = 0,
        iteration: int = 0,
        idx: int = 0,
    ) -> None:
        """Hook after each validation step."""
        pass

    def on_validation_end(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
        """Hook after a validation run completes."""
        pass

    def on_save_checkpoint_start(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook before a checkpoint save is attempted."""
        pass

    def on_save_checkpoint_success(self, model: FastGenModel, iteration: int = 0, path: str | None = None) -> None:
        """Hook after a checkpoint was saved successfully; ``path`` is the saved file."""
        pass

    def on_save_checkpoint_end(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook after checkpoint saving finishes (success or not)."""
        pass

    def on_app_end(self, model: FastGenModel, iteration: int = 0) -> None:
        """Hook at application shutdown."""
        pass

    def state_dict(self) -> dict[str, Any]:
        """Return callback state to persist in checkpoints (default: none)."""
        return {}

    def load_state_dict(self, state_dict: dict[str, Any]) -> None:
        """Restore callback state saved by :meth:`state_dict` (default: no-op)."""
        pass
FastGen/fastgen/callbacks/ct_schedule.py ADDED
@@ -0,0 +1,83 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+ from typing import Callable, TYPE_CHECKING
6
+ import wandb
7
+
8
+ import torch
9
+
10
+ from fastgen.callbacks.callback import Callback
11
+ import fastgen.utils.logging_utils as logger
12
+ from fastgen.utils.basic_utils import get_batch_size_total
13
+ from fastgen.utils.distributed import is_rank0
14
+
15
+ if TYPE_CHECKING:
16
+ from fastgen.methods import FastGenModel
17
+ from fastgen.configs.config import BaseConfig
18
+
19
+
20
class CTScheduleCallback(Callback):
    """Staged consistency-training (CT) schedule.

    The total number of images seen so far determines a discrete stage; each
    stage maps to a ratio ``1 - q ** -(stage + 1)``, clipped at ``ratio_limit``,
    which is attached to the model as ``model.ratio``.
    """

    config: "BaseConfig"

    def __init__(
        self,
        q: float = 2.0,
        ratio_limit: float = 0.999,
        kimg_per_stage: int = 12500,
        batch_size: int = 1,
    ):
        """Store schedule hyper-parameters and reset the schedule state.

        Args:
            q: Base of the geometric ratio progression.
            ratio_limit: Upper bound applied to the computed ratio.
            kimg_per_stage: Thousands of images per schedule stage.
            batch_size: Fallback global batch size (overridden by ``self.config``
                when a config is attached).
        """
        self.q = q
        self.ratio_limit = ratio_limit
        self.kimg_per_stage = kimg_per_stage
        self.batch_size = batch_size

        # Mutable schedule state.
        self.stage = 0
        self.ratio = 0.0

    def _get_cur_stage(self, model, iteration):
        """Return ``(stage, images_seen)`` for the given training iteration."""
        # Start from the saved iteration of the first-stage model in TCM.
        if hasattr(model, "resume_iter"):
            assert isinstance(model.resume_iter, int)
            iteration = iteration + model.resume_iter

        effective_batch_size = self.batch_size
        if hasattr(self, "config"):
            # The config-derived global batch size takes precedence.
            effective_batch_size = get_batch_size_total(self.config)

        cur_nimg = iteration * effective_batch_size
        return cur_nimg // (self.kimg_per_stage * 1000), cur_nimg

    def _update_schedule(self, stage):
        """Advance to ``stage`` and recompute the ratio, clipping at the limit."""
        self.stage = stage
        new_ratio = 1 - 1 / self.q ** (stage + 1)
        if new_ratio > self.ratio_limit:
            logger.info(f"Clipping ratio from {new_ratio} -> {self.ratio_limit}")
            new_ratio = self.ratio_limit
        self.ratio = new_ratio

    def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Initialize the schedule and expose ``model.ratio`` before training."""
        stage, _ = self._get_cur_stage(model, iteration)
        self._update_schedule(stage)
        model.ratio = self.ratio

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Advance the schedule when a stage boundary is crossed and log it."""
        del data_batch, output_batch, loss_dict
        new_stage, cur_nimg = self._get_cur_stage(model, iteration)
        if new_stage > self.stage:
            self._update_schedule(new_stage)
            model.ratio = self.ratio

        # Only log to wandb when a config (and hence a logging interval) exists.
        if hasattr(self, "config"):
            if iteration % self.config.trainer.logging_iter == 0 and is_rank0() and wandb.run:
                wandb.log({"ct_schedule/kimg": cur_nimg / 1e3, "ct_schedule/ratio": self.ratio}, step=iteration)
FastGen/fastgen/callbacks/ema.py ADDED
@@ -0,0 +1,155 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import Callable, TYPE_CHECKING, Optional
7
+
8
+ import torch
9
+ import wandb
10
+
11
+ from fastgen.callbacks.callback import Callback
12
+ from fastgen.utils.basic_utils import get_batch_size_total
13
+ from fastgen.utils.distributed import synchronize, is_rank0
14
+ import fastgen.utils.logging_utils as logger
15
+
16
+ if TYPE_CHECKING:
17
+ from fastgen.methods import FastGenModel
18
+
19
+
20
class EMACallback(Callback):
    """Maintains an exponential moving average (EMA) copy of ``model.net``.

    After every training step, the parameters of the module registered on the
    model under ``ema_name`` are moved toward ``model.net``'s parameters via an
    in-place ``lerp_`` with weight ``1 - beta``; buffers are copied directly.
    Supports three decay schedules: ``constant``, ``power``, and ``halflife``.
    Handles FSDP2 ``DTensor`` parameters (including CPU offloading) by
    gathering the full tensor before the update.
    """

    def __init__(
        self,
        # NOTE: 'type' shadows the builtin, but it is part of the public
        # keyword interface and cannot be renamed without breaking callers.
        type: str = "constant",
        # params for type=constant
        beta: float = 0.9999,
        # params for type=power
        gamma: float = 16.97,
        # params for type=halflife
        ema_halflife_kimg: float = 500,
        ema_rampup_ratio: Optional[float] = 0.05,
        ema_name: str = "ema",
        batch_size: int = 1,  # overwritten by self.config if it exists
        fsdp: bool = False,  # overwritten by self.config if it exists
    ):
        """Store EMA schedule hyper-parameters; no model access happens here."""
        self.type = type
        self.beta = beta
        self.gamma = gamma
        self.ema_halflife_kimg = ema_halflife_kimg
        self.ema_rampup_ratio = ema_rampup_ratio
        self.ema_name = ema_name
        self.batch_size = batch_size
        self._is_fsdp = fsdp
        self._enabled = True

    def on_app_begin(self) -> None:
        """Sync ``batch_size`` and FSDP flag from the attached config, if any."""
        if hasattr(self, "config"):
            # override using config
            self._is_fsdp = self.config.trainer.fsdp
            self.batch_size = get_batch_size_total(self.config)

    def on_model_init_end(
        self, model: FastGenModel | torch.nn.parallel.DistributedDataParallel, iteration: int = 0
    ) -> None:
        """Validate the EMA module; disable this callback if it is absent."""
        # Unwrap DDP if needed to access the original model's attributes
        if hasattr(model, "module"):
            model = model.module

        # check ema initialization
        ema = getattr(model, self.ema_name, None)
        if ema is None:
            self._enabled = False
            logger.info(f"EMA {self.ema_name} is not enabled, skipping callback.")
            return

        assert ema.training is False, f"EMA {self.ema_name} should be in eval mode"
        for name, p_net in ema.named_parameters():
            assert not p_net.requires_grad, f"EMA parameter {name} should not require gradients"

    def _total_iteration(self, model: FastGenModel, iteration: int) -> int:
        """Return the iteration offset by the model's resume point, if present."""
        if hasattr(model, "resume_iter"):
            assert isinstance(model.resume_iter, int)
            iteration = iteration + model.resume_iter
        return iteration

    def _power_function_beta(self, iteration):
        """Power-function decay schedule (Karras-style).

        NOTE(review): divides by ``iteration`` — assumes it is called with
        iteration >= 1; iteration == 0 would raise ZeroDivisionError. Confirm
        against the trainer's iteration numbering.
        """
        beta = (1 - 1 / iteration) ** (self.gamma + 1)
        return beta

    def _get_cur_nimg(self, iteration):
        """Return ``(batch_size, images_seen)`` for the given iteration."""
        cur_nimg = iteration * self.batch_size
        return self.batch_size, cur_nimg

    def _halflife_beta(self, iteration):
        """Half-life decay schedule with optional warm-up ramp."""
        ema_halflife_nimg = self.ema_halflife_kimg * 1000
        batch_size, cur_nimg = self._get_cur_nimg(iteration)
        if self.ema_rampup_ratio is not None:
            # During ramp-up the effective half-life grows with images seen.
            ema_halflife_nimg = min(ema_halflife_nimg, cur_nimg * self.ema_rampup_ratio)
        ema_beta = 0.5 ** (batch_size / max(ema_halflife_nimg, 1e-8))
        return ema_beta

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Blend ``model.net`` into the EMA module with the scheduled beta."""
        del data_batch, output_batch, loss_dict

        # Check if EMA is enabled
        if not self._enabled:
            return

        # Select the decay factor according to the configured schedule.
        if self.type == "constant":
            beta = self.beta
        elif self.type == "power":
            beta = self._power_function_beta(self._total_iteration(model, iteration))
        elif self.type == "halflife":
            beta = self._halflife_beta(self._total_iteration(model, iteration))
        else:
            raise ValueError(f"Invalid {self.ema_name} type: {self.type}")

        with torch.no_grad():
            ema = getattr(model, self.ema_name)
            # state_dict tensors alias the EMA module's storage, so the
            # in-place lerp_/copy_ below update the EMA module directly.
            ema_state_dict = ema.state_dict()

            for name, p_net in model.net.named_parameters():
                if self._is_fsdp and hasattr(p_net, "full_tensor"):
                    # Gather the full tensor from all ranks if using FSDP with DTensor
                    # When CPU offloading is enabled, we need to move to CUDA first because
                    # full_tensor() performs an all_gather which requires a CUDA backend
                    if p_net.device.type == "cpu":
                        # Move local shard to CUDA, gather, then the result stays on CUDA
                        # which is fine since we'll copy to EMA (which handles device placement)
                        full_tensor = p_net.to("cuda").full_tensor()
                    else:
                        full_tensor = p_net.full_tensor()
                else:
                    full_tensor = p_net
                # Strip checkpoint wrapper prefix if present (EMA doesn't have checkpointing)
                ema_name = name.replace("_checkpoint_wrapped_module.", "")
                # Cast to EMA dtype and device (typically float32 on CPU) for lerp_ compatibility
                if ema_name in ema_state_dict:
                    ema_param = ema_state_dict[ema_name]
                    ema_param.lerp_(full_tensor.to(device=ema_param.device, dtype=ema_param.dtype), 1.0 - beta)
                elif iteration == 1:
                    # only warn on first iteration if parameter is not found
                    logger.warning(f"EMA parameter {ema_name} not found in EMA state dict, skipping update.")

            # FSDP2 doesn't shard buffers, so we can just copy them
            for name, p_net in model.net.named_buffers():
                if name in ema_state_dict:
                    ema_param = ema_state_dict[name]
                    ema_param.copy_(p_net.to(device=ema_param.device, dtype=ema_param.dtype))
                elif iteration == 1:
                    # only warn on first iteration if buffer is not found
                    logger.warning(f"EMA buffer {name} not found in EMA state dict, skipping update.")

        if hasattr(self, "config"):
            # only wandb log when config exists
            if iteration % self.config.trainer.logging_iter == 0 and is_rank0():
                if wandb.run:
                    wandb.log({f"ema/{self.ema_name}_beta": beta}, step=iteration)
        synchronize()
FastGen/fastgen/callbacks/forced_weight_norm.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from typing import TYPE_CHECKING
7
+
8
+
9
+ from fastgen.callbacks.callback import Callback
10
+ import fastgen.utils.logging_utils as logger
11
+
12
+ if TYPE_CHECKING:
13
+ from fastgen.methods import FastGenModel
14
+
15
+
16
class ForcedWeightNormCallback(Callback):
    """Re-normalizes network weights before each gradient-accumulation step."""

    def on_training_accum_step_begin(
        self,
        model: FastGenModel,
        *args,
        **kwargs,
    ) -> None:
        """Call ``model.net.forced_weight_normalization()`` when available.

        Logs a warning (instead of raising) when the network does not expose
        the method, so a misconfiguration is visible but non-fatal.
        """
        net = model.net
        if not hasattr(net, "forced_weight_normalization"):
            logger.warning(
                "Enabled ForcedWeightNormCallback but model.net does not have the forced_weight_normalization method."
            )
            return
        net.forced_weight_normalization()
FastGen/fastgen/callbacks/gpu_mem_profiler.py ADDED
@@ -0,0 +1,134 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import torch
7
+ import os
8
+ from fastgen.utils import logging_utils as logger
9
+ from fastgen.callbacks.callback import Callback
10
+ import atexit
11
+ import pickle
12
+ from typing import Callable, Optional, TYPE_CHECKING
13
+ import base64
14
+ import json
15
+
16
+ if TYPE_CHECKING:
17
+ from fastgen.methods import FastGenModel
18
+
19
+
20
def create_dump(dump_path):
    """Dump the current CUDA allocator history as a self-contained HTML timeline.

    Takes a snapshot via ``torch.cuda.memory._dump_snapshot`` (also written to
    ``dump_path + ".pickle"``), keeps only never-freed allocations above 1 KB,
    merges adjacent allocations that share a common fastgen callstack ancestor,
    and embeds the reduced trace into PyTorch's MemoryViz HTML template.

    Args:
        dump_path: Output path for the HTML file (should end with ``.html``).
    """
    logger.critical(f"Creating {dump_path}")
    if not dump_path.endswith("html"):
        print(f"[{__file__}] create_dump produces an HTML file but was called with {dump_path=}")
    torch.cuda.memory._dump_snapshot(dump_path + ".pickle")
    with open(dump_path + ".pickle", "rb") as f:
        data = pickle.load(f)
    _memory_viz_template = r"""
    <!DOCTYPE html>
    <html>
    <head>
    </head>
    <body>
    <script type="module">
    import {add_local_files} from "https://cdn.jsdelivr.net/gh/pytorch/pytorch@main/torch/utils/viz/MemoryViz.js"
    const local_files = $SNAPSHOT
    add_local_files(local_files, $VIZ_KIND)
    </script>
    </body>
    """

    # Find which GPU was active. Iterate over however many device traces the
    # snapshot contains (the previous hard-coded range(8) raised IndexError on
    # machines exposing fewer than 8 devices). idx_device stays -1 (last
    # device) when no trace has events, preserving the old fallback behavior.
    idx_device = -1
    for i in range(len(data["device_traces"])):
        if data["device_traces"][i]:
            idx_device = i
            break

    traces = data["device_traces"][idx_device]  # create an aliasing variable for convenience
    traces = [
        d for d in traces if d["action"] == "alloc" or d["action"] == "free_completed"
    ]  # only the `alloc` and `free_completed` events matter for our visualization

    for d in traces:
        d["fastgen_frames"] = [
            f for f in d["frames"] if "fastgen" in f["filename"]
        ]  # get the callstack frames from fastgen code (e.g. ignore frames in pytorch/other libraries)
        if not d["fastgen_frames"]:
            d["fastgen_frames"] = d["frames"]

    # run through the trace and find allocations that were allocated but never freed
    set_alloced_addrs: dict = {}
    for d in traces:
        if d["action"] == "alloc":
            set_alloced_addrs[d["addr"]] = d
        elif d["action"] == "free_completed":
            if d["addr"] in set_alloced_addrs:
                del set_alloced_addrs[d["addr"]]
        else:
            raise NotImplementedError(f"{d['action']}")

    never_freed_traces = list(set_alloced_addrs.values())
    KB = 1 << 10
    never_freed_traces = [t for t in never_freed_traces if t["size"] > KB]  # get rid of allocations below 1 KB

    # now proceed through the trace (guaranteed to be all `alloc` events as we removed all free events).
    # for each pair of alloc events, merge them iff they share a common fastgen ancestor.
    # Merging events is useful as it both speeds up the visualization rendering and also makes it more understandable.
    i = 0
    while i < len(never_freed_traces) - 1:
        curr_frames = never_freed_traces[i]["fastgen_frames"]
        next_frames = never_freed_traces[i + 1]["fastgen_frames"]
        if (
            curr_frames and next_frames and curr_frames[0] == next_frames[0]
        ):  # TODO: probably should compare the full callstack
            # same ancestor, delete next event and add its size to current event
            never_freed_traces[i]["size"] += never_freed_traces[i + 1]["size"]
            never_freed_traces.pop(i + 1)
        else:
            i += 1  # different ancestor, do not combine and move on

    never_freed_traces = never_freed_traces  # merged-alloc events only
    data["device_traces"][idx_device] = never_freed_traces  # update the trace to only be the merged-alloc events
    data["segments"] = []  # shrink the trace, unused in memory timeline
    data["external_annotations"] = []  # shrink the trace, unused in memory timeline
    buffer = pickle.dumps(data)
    buffer += b"\x00" * (3 - len(buffer) % 3)
    encoded_buffer = base64.b64encode(buffer).decode("utf-8")
    json_format = json.dumps([{"name": "snapshot.pickle", "base64": encoded_buffer}])
    html_src = _memory_viz_template.replace("$VIZ_KIND", repr("Active Memory Timeline")).replace(
        "$SNAPSHOT", json_format
    )
    with open(dump_path, "w") as f:
        f.write(html_src)
103
+
104
+
105
class MemTrackerCallback(Callback):
    """Records CUDA allocator history and dumps HTML memory timelines.

    A dump is also registered via ``atexit`` so a timeline is produced when the
    process exits (e.g. after a crash).
    """

    def __init__(self, save_every_n_iters: Optional[int] = None, deactivate_after_n_iters: int = 100):
        """Configure dump cadence and register the exit-time dump.

        Args:
            save_every_n_iters: If set, write a dump every N training steps.
            deactivate_after_n_iters: Stop recording after this many iterations
                (recording indefinitely leaks host memory).
        """

        def _dump_on_exit():
            out_root = os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT")
            rank = os.environ.get("RANK", "0")
            create_dump(f"{out_root}/crash_rank{rank}.html")

        self.deactivate_after_n_iters = deactivate_after_n_iters  # Deactivate eventually to prevent leaking host memory
        self.save_every_n_iters = save_every_n_iters
        self.atexit_fn = _dump_on_exit
        atexit.register(self.atexit_fn)

    def on_app_begin(self):
        """Start recording allocator events with Python stack traces."""
        logger.info("[MemTrackerCallback] Tracking peak memory usage")
        torch.cuda.memory._record_memory_history(stacks="python")

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Periodically dump a timeline; disable recording after the cutoff."""
        if iteration > self.deactivate_after_n_iters:
            # Disabling frees PyTorch's internal tracking data structures.
            torch.cuda.memory._record_memory_history(enabled=None)
        if self.save_every_n_iters is not None and (iteration % self.save_every_n_iters) == 0:
            out_root = os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT")
            rank = os.environ.get("RANK", "0")
            create_dump(f"{out_root}/step{iteration}_rank{rank}.html")
FastGen/fastgen/callbacks/gpu_stats.py ADDED
@@ -0,0 +1,92 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import os
7
+ from typing import TYPE_CHECKING, Callable, Any, Dict, List
8
+
9
+ import pandas as pd
10
+ import psutil
11
+ import torch
12
+
13
+ from fastgen.callbacks.callback import Callback
14
+ from fastgen.utils.distributed import world_size, is_rank0, synchronize
15
+ import fastgen.utils.logging_utils as logger
16
+
17
+ if TYPE_CHECKING:
18
+ from fastgen.methods import FastGenModel
19
+
20
+
21
def log_prof_data(data_list: List[Dict[str, Any]]):
    """Aggregate per-rank profiling dicts and log an Avg/Max/Min summary table.

    Args:
        data_list: One metrics dict per rank; all dicts share the same keys.
    """
    # All ranks report the same metric names; take them from the first entry.
    metrics = list(data_list[0].keys())

    # Aggregate each metric across ranks.
    min_values = {key: min(d[key] for d in data_list) for key in metrics}
    max_values = {key: max(d[key] for d in data_list) for key in metrics}
    avg_values = {key: sum(d[key] for d in data_list) / len(data_list) for key in metrics}

    summary_df = pd.DataFrame({"Avg": avg_values, "Max": max_values, "Min": min_values})

    logger.info(f"GPU stats:\n{summary_df.to_string()}")
46
+
47
+
48
class GPUStatsCallback(Callback):
    """Periodically gathers CPU/GPU memory and utilization stats across ranks."""

    def __init__(self, every_n: int = 100):
        """Args:
        every_n: Measure and log stats every N iterations (overridden by
            ``config.trainer.logging_iter`` when a config is attached).
        """
        self.every_n = every_n

    def on_train_begin(self, model: FastGenModel, iteration: int = 0):
        """Reset peak-memory counters and sync the interval with the config."""
        torch.cuda.reset_peak_memory_stats()
        if hasattr(self, "config"):
            # The trainer-wide logging interval takes precedence.
            self.every_n = self.config.trainer.logging_iter
        logger.info(f"every_n to measure gpus stats: {self.every_n}")

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Every ``every_n`` iterations, collect per-rank stats and log them."""
        del data_batch, output_batch, loss_dict
        if iteration % self.every_n != 0:
            return

        # RSS of this process plus all of its (recursive) children.
        cur_process = psutil.Process(os.getpid())
        cpu_memory_usage = sum(p.memory_info().rss for p in [cur_process] + cur_process.children(recursive=True))

        GB = 1024**3
        prof_data = {
            "cpu_mem_gb": float(cpu_memory_usage / GB),
            "peak_gpu_mem_gb": float(torch.cuda.max_memory_allocated() / GB),
            "peak_gpu_mem_reserved_gb": float(torch.cuda.max_memory_reserved() / GB),
            "util": float(torch.cuda.utilization()),
        }

        synchronize()
        data_list = [prof_data] * world_size()
        if world_size() > 1:
            # all_gather_object is blocking by default.
            torch.distributed.all_gather_object(data_list, prof_data)

        if is_rank0():
            log_prof_data(data_list)
        synchronize()
FastGen/fastgen/callbacks/grad_clip.py ADDED
@@ -0,0 +1,222 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ from contextlib import contextmanager
7
+ from typing import TYPE_CHECKING, Optional
8
+ import wandb
9
+
10
+ import torch
11
+ from torch.distributed.tensor import DTensor
12
+
13
+ from fastgen.callbacks.callback import Callback
14
+ from fastgen.utils.distributed import is_rank0, world_size
15
+ import fastgen.utils.logging_utils as logger
16
+
17
+ if TYPE_CHECKING:
18
+ from fastgen.methods import FastGenModel
19
+
20
+
21
@contextmanager
def cast_gradients_dtype(model, dtype=torch.float32, enabled=True):
    """Temporarily cast all existing gradients of ``model`` to ``dtype``.

    On exit, each gradient is restored to the dtype it had when the context
    was entered. (The previous implementation restored gradients to the
    *parameter's* dtype instead, which silently changed gradients that were
    deliberately kept in a different precision than their parameters.)

    Args:
        model: Module whose parameters' ``.grad`` tensors are cast.
        dtype: Target dtype inside the context.
        enabled: When False, the context is a no-op.
    """
    if not enabled:
        yield
        return

    # Remember the entry dtype of every gradient we actually cast, so we can
    # restore exactly that dtype afterwards (Parameters are hashable by id).
    original_dtypes: dict = {}
    try:
        for param in model.parameters():
            if param.grad is not None and param.grad.dtype != dtype:
                original_dtypes[param] = param.grad.dtype
                param.grad.data = param.grad.data.to(dtype)
        yield
    finally:
        # Restore original gradient dtypes.
        for param, orig_dtype in original_dtypes.items():
            if param.grad is not None and param.grad.dtype != orig_dtype:
                param.grad.data = param.grad.data.to(orig_dtype)
37
+
38
+
39
def clip_grad_norm_fsdp(
    parameters,
    max_norm: float,
    norm_type: float = 2.0,
    device: Optional[torch.device] = None,
) -> torch.Tensor:
    """
    Clip gradients for FSDP2 models with CPU offloading.

    The standard torch.nn.utils.clip_grad_norm_ fails with FSDP2 CPU offloading because
    DTensor operations (like division) trigger all_reduce on CPU, which has no backend.

    This implementation:
    1. Extracts local tensors from DTensors
    2. Computes local norms on native device (CPU or GPU)
    3. All-reduces the scalar norm to get global norm
    4. Clips gradients in-place using the global norm

    Args:
        parameters: Iterable of parameters with gradients
        max_norm: Maximum norm value
        norm_type: Type of norm (default: L2)
        device: Device for all-reduce tensor. If None, inferred from gradients or defaults to cuda.

    Returns:
        Total gradient norm (global across all ranks) as a regular tensor

    NOTE(review): the sum-of-p-th-powers aggregation assumes a finite
    ``norm_type``; ``norm_type=inf`` would not be aggregated correctly here.
    NOTE(review): non-DTensor gradients that are replicated across ranks are
    counted once per rank by the SUM all-reduce — confirm callers only pass
    sharded (FSDP) parameters.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = list(p for p in parameters if p.grad is not None)

    # Nothing to clip: report a zero norm.
    if len(parameters) == 0:
        return torch.tensor(0.0)

    # Compute per-parameter norms on their native device (CPU or GPU)
    # We compute norm^norm_type to allow proper aggregation across ranks
    local_norm_sum = 0.0
    inferred_device = None
    for p in parameters:
        if isinstance(p.grad, DTensor):
            # Operate on the local shard only; DTensor math would trigger
            # collective ops, which is exactly what we are avoiding.
            grad = p.grad._local_tensor
        else:
            grad = p.grad

        # Infer CUDA device from gradients (use first CUDA device found)
        if inferred_device is None and grad.device.type == "cuda":
            inferred_device = grad.device

        # Compute norm on the gradient's native device, accumulate as Python float
        local_norm_sum += torch.norm(grad.detach().float(), norm_type).item() ** norm_type

    # Use provided device, or inferred device, or fall back to current CUDA device
    if device is None:
        device = inferred_device if inferred_device is not None else torch.device("cuda")
    local_norm_sum = torch.tensor(local_norm_sum, device=device)

    # All-reduce to get global norm across all ranks
    if world_size() > 1:
        torch.distributed.all_reduce(local_norm_sum, op=torch.distributed.ReduceOp.SUM)

    # Compute final global norm
    total_norm = local_norm_sum ** (1.0 / norm_type)

    # Compute clip coefficient (regular tensor division, no DTensor ops)
    clip_coef = max_norm / (total_norm + 1e-6)
    clip_coef_clamped = torch.clamp(clip_coef, max=1.0)

    # Apply clipping to gradients
    for p in parameters:
        if isinstance(p.grad, DTensor):
            # For DTensor, scale the local tensor directly
            local_grad = p.grad._local_tensor
            local_grad.mul_(clip_coef_clamped.to(local_grad.device))
        else:
            p.grad.detach().mul_(clip_coef_clamped.to(p.grad.device))

    return total_norm
116
+
117
+
118
class GradClipCallback(Callback):
    """Sanitizes and clips gradients of one sub-network before the optimizer step.

    Non-finite gradient entries are replaced via ``nan_to_num`` and, when
    ``grad_norm`` is set, the global gradient norm is clipped — using a custom
    path for FSDP2 with CPU offloading where the standard utility fails.
    """

    def __init__(
        self,
        grad_norm: float | None = 1.0,
        model_key: str = "net",
        posinf: float | None = None,
        neginf: float | None = None,
        precision_grad_clip: Optional[torch.dtype] = None,
    ) -> None:
        """Args:
        grad_norm: Max gradient norm; ``None`` disables norm clipping.
        model_key: Dotted attribute path of the sub-network to clip
            (also selects the optimizer in ``model.optimizer_dict``).
        posinf: Replacement for +inf gradient entries (``nan_to_num``).
        neginf: Replacement for -inf gradient entries (``nan_to_num``).
        precision_grad_clip: If set, gradients are cast to this dtype during
            clipping for numerical stability.
        """
        self.grad_norm = grad_norm
        self.model_key = model_key
        self.posinf = posinf
        self.neginf = neginf
        self.precision_grad_clip = precision_grad_clip

    def nan_to_num(self, module: torch.nn.Module) -> tuple[int, torch.dtype | None]:
        """Replace NaN/Inf gradient entries in-place across ``module``.

        Returns:
            A tuple ``(count, dtype)`` where ``count`` is the total number of
            non-finite gradient entries found and ``dtype`` is the dtype of the
            last gradient seen (``None`` when no parameter had a gradient).
        """
        grad_dtype = None
        non_finite_grads_count = 0

        for name, param in module.named_parameters():
            if param.grad is not None:
                grad_dtype = param.grad.dtype

                # Extract local tensor for DTensor (avoids triggering distributed ops on CPU)
                if isinstance(param.grad, DTensor):
                    grad = param.grad._local_tensor
                else:
                    grad = param.grad

                non_finite_grads = grad.numel() - grad.isfinite().sum().item()
                if non_finite_grads:
                    non_finite_grads_count += non_finite_grads
                    logger.debug(
                        f"Gradient of {name} (dtype {grad_dtype}) is not finite: "
                        f"Setting {grad.isnan().sum().item()} NaNs to 0 and {grad.isinf().sum().item()} Infs "
                        f"to {self.posinf} or {self.neginf}."
                    )
                # In-place (out=grad) so the original gradient storage is patched.
                torch.nan_to_num(grad, nan=0.0, posinf=self.posinf, neginf=self.neginf, out=grad)

        return non_finite_grads_count, grad_dtype

    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Unscale, sanitize, and clip the selected sub-network's gradients."""
        # unscale the optimizer related to the `model_key`
        assert (
            self.model_key in model.optimizer_dict.keys()
        ), f"Keys in optimizer_dict: {list(model.optimizer_dict.keys())}."
        optimizer = model.optimizer_dict[self.model_key]
        # Only unscale if grad_scaler should be used (checks enabled + float32 grads)
        if model.should_use_grad_scaler(optimizer):
            model.grad_scaler.unscale_(optimizer)

        # Save model device before selecting subnet (subnet may not have .device)
        model_device = model.device

        # select subnet if specified (by default, we only perform gradient clips on model.net)
        # NOTE: `model` is rebound to the selected sub-module from here on.
        subnets = self.model_key.split(".")
        for subnet in subnets:
            model = getattr(model, subnet)

        # set nan to num for each parameter
        non_finite_grads_count, grad_dtype = self.nan_to_num(model)
        logger.debug(f"Gradient dtype of {self.model_key}: {grad_dtype}")
        if non_finite_grads_count > 0:
            logger.info(
                f"Number of parameters with non-finite gradients (of dtype {grad_dtype}): {non_finite_grads_count}"
            )
        log_dict = {f"optimizer/non_finite_grads_count (model_key {self.model_key})": non_finite_grads_count}

        if self.grad_norm is not None:
            # Cast all gradients to precision_grad_clip for numerical stability during clipping
            cast_grads = (
                self.precision_grad_clip is not None
                and grad_dtype is not None
                and grad_dtype != self.precision_grad_clip
            )

            # log value at first iteration
            if iteration == 1 and cast_grads:
                logger.info(f"Casting gradients from {grad_dtype} to {self.precision_grad_clip} before clipping.")

            # Check if CPU offloading is enabled by looking for DTensor grads on CPU
            # CPU offloading = DTensor local tensors are on CPU
            # No CPU offloading = DTensor local tensors are on GPU (or no DTensors)
            use_fsdp_cpu_offload_clip = False
            for p in model.parameters():
                if p.grad is not None and isinstance(p.grad, DTensor):
                    if p.grad._local_tensor.device.type == "cpu":
                        use_fsdp_cpu_offload_clip = True
                        break

            with cast_gradients_dtype(model, dtype=self.precision_grad_clip, enabled=cast_grads):
                if use_fsdp_cpu_offload_clip:
                    # Use custom clipping for FSDP with CPU offloading
                    # Standard clip_grad_norm_ fails because DTensor ops trigger all_reduce on CPU
                    total_norm = clip_grad_norm_fsdp(model.parameters(), self.grad_norm, device=model_device)
                else:
                    # Standard clipping for non-FSDP or FSDP without CPU offloading
                    total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), self.grad_norm, foreach=True)

            log_dict[f"optimizer/grad_norm (model_key {self.model_key})"] = total_norm.item()

        if hasattr(self, "config"):
            # only wandb log when config exists
            if iteration % self.config.trainer.logging_iter == 0 and is_rank0() and wandb.run:
                wandb.log(log_dict, step=iteration)
FastGen/fastgen/callbacks/param_count.py ADDED
@@ -0,0 +1,116 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+ from typing import TYPE_CHECKING
6
+
7
+ from fastgen.callbacks.callback import Callback
8
+ from fastgen.utils.distributed import world_size
9
+ import fastgen.utils.logging_utils as logger
10
+ import torch
11
+ import wandb
12
+
13
+ try:
14
+ from torch.distributed.tensor import DTensor
15
+ except ImportError:
16
+ DTensor = None
17
+
18
+ if TYPE_CHECKING:
19
+ from fastgen.methods import FastGenModel
20
+
21
+
22
+ def _get_local_numel(param: torch.Tensor) -> int:
23
+ """Get the local (sharded) number of elements for a parameter.
24
+
25
+ For DTensor (FSDP2), returns the local shard size.
26
+ For regular tensors, returns the full size.
27
+ """
28
+ if DTensor is not None and isinstance(param, DTensor):
29
+ return param._local_tensor.numel()
30
+ return param.numel()
31
+
32
+
33
class ParamCountCallback(Callback):
    """Logs logical and local (sharded) parameter counts for every sub-model.

    For each entry in ``model.model_dict`` (plus the model itself), counts
    trainable/total parameters both logically (full tensors) and locally
    (per-rank shard sizes via ``_get_local_numel``), cross-checks the counts
    across ranks, logs a summary, and records the numbers in the wandb run
    summary when a run is active.
    """

    def on_train_begin(self, model: FastGenModel, **kwargs) -> None:
        """Count, cross-check, and log parameter statistics at training start."""
        # get modules
        modules = {"model": model, **model.model_dict}

        # iterate over modules
        output = {}
        for name, module in modules.items():
            # Logical (full model) param counts
            trainable_params = sum(p.numel() for p in module.parameters() if p.requires_grad)
            total_params = sum(p.numel() for p in module.parameters())

            # Local (sharded) param counts - what's actually in memory on this rank
            local_trainable_params = sum(_get_local_numel(p) for p in module.parameters() if p.requires_grad)
            local_total_params = sum(_get_local_numel(p) for p in module.parameters())

            # check if parameter counts are different across ranks
            if world_size() > 1:
                trainable_params = self.gather_param_counts(trainable_params)
                total_params = self.gather_param_counts(total_params)
                local_trainable_params = self.gather_param_counts(local_trainable_params)
                local_total_params = self.gather_param_counts(local_total_params)
                # Collapse the per-rank lists back to scalars when all ranks agree;
                # a remaining list signals a rank mismatch and is reported below.
                if len(set(total_params)) == 1 and len(set(trainable_params)) == 1:
                    trainable_params = trainable_params[0]
                    total_params = total_params[0]
                if len(set(local_total_params)) == 1 and len(set(local_trainable_params)) == 1:
                    local_trainable_params = local_trainable_params[0]
                    local_total_params = local_total_params[0]

            # logging
            module_name = module.__class__.__name__
            output.update(
                {
                    f"{name}/trainable_params": trainable_params,
                    f"{name}/total_params": total_params,
                    f"{name}/local_trainable_params": local_trainable_params,
                    f"{name}/local_total_params": local_total_params,
                }
            )
            if isinstance(trainable_params, list):
                logger.warning(f"Parameter counts differ across ranks for {module_name}.")
                for rank, (p_train, p) in enumerate(zip(trainable_params, total_params)):
                    logger.info(
                        f"{name} ({module_name}) has {p_train * 1.e-6:.2f} M trainable and {p * 1.e-6:.2f} M total params on rank {rank}."
                    )
            else:
                logger.info(
                    f"{name} ({module_name}) has {trainable_params * 1.e-6:.2f} M trainable and {total_params * 1.e-6:.2f} M total params (logical)."
                )

            # Report local/sharded counts
            if isinstance(local_trainable_params, list):
                for rank, (p_train, p) in enumerate(zip(local_trainable_params, local_total_params)):
                    logger.info(
                        f"{name} ({module_name}) has {p_train * 1.e-6:.2f} M trainable and {p * 1.e-6:.2f} M total params LOCAL on rank {rank}."
                    )
            else:
                is_sharded = local_total_params < total_params if not isinstance(total_params, list) else True
                if is_sharded:
                    logger.info(
                        f"{name} ({module_name}) has {local_trainable_params * 1.e-6:.2f} M trainable and {local_total_params * 1.e-6:.2f} M total params LOCAL per rank (sharding ratio: {world_size()}x)."
                    )
                else:
                    logger.info(f"{name} ({module_name}) is NOT sharded (local == logical params).")

        if wandb.run:
            wandb.run.summary.update(output)

    def gather_param_counts(self, param_count):
        """
        Gather parameter counts across all ranks.

        Args:
            param_count: Parameter count to gather.

        Returns:
            List of parameter counts across all ranks.
        """
        param_count = torch.tensor(
            [param_count], dtype=torch.long, device="cuda" if torch.cuda.is_available() else "cpu"
        )
        param_count_list = [torch.zeros_like(param_count) for _ in range(world_size())]
        torch.distributed.all_gather(param_count_list, param_count)
        return [p.item() for p in param_count_list]
FastGen/fastgen/callbacks/train_profiler.py ADDED
@@ -0,0 +1,138 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+
6
+ import time
7
+ from typing import TYPE_CHECKING, Callable
8
+
9
+ import torch
10
+ import wandb
11
+
12
+ from fastgen.callbacks.callback import Callback
13
+ from fastgen.utils.distributed import is_rank0
14
+ import fastgen.utils.logging_utils as logger
15
+
16
+ if TYPE_CHECKING:
17
+ from fastgen.methods import FastGenModel
18
+
19
+
20
class TrainProfilerCallback(Callback):
    """Profile training speed with optional per-phase timing breakdown.

    Reported quantities:
    - avg iteration time: wall-clock seconds per iteration between logs
    - data loading time: gap between step begin and the first accum step
    - avg forward pass time: mean forward time over accumulation steps
    - backward pass time: last backward (backward begin to optimizer begin)
    - optimizer step time: optimizer begin to step end
    """

    def __init__(self, every_n: int = 100, detailed: bool = True):
        """Set up timing state.

        Args:
            every_n: Emit metrics every N iterations (replaced by
                ``config.trainer.logging_iter`` when a config is attached).
            detailed: When True, also record the per-phase breakdown;
                otherwise only the averaged iteration time is reported.
        """
        self.every_n = every_n
        self.detailed = detailed

        # Wall-clock anchor used to average iteration time between logs.
        self.last_log_time = None

        # Timestamps captured within a single training step (detailed mode).
        self.train_step_begin_time = None
        self.accum_begin_times = None
        self.backward_begin_times = None
        self.optimizer_step_begin = None
        self.step_end_time = None

    def on_train_begin(self, model: FastGenModel, iteration: int = 0) -> None:
        """Sync the logging interval with the trainer config, if attached."""
        if hasattr(self, "config"):
            # overwritten by logging_iter if self.config exists
            self.every_n = self.config.trainer.logging_iter
        logger.info(f"every_n to profile trainer: {self.every_n}")

    def on_training_step_begin(
        self,
        model: FastGenModel,
        iteration: int = 0,
    ):
        """Mark the start of a step and reset per-step timing buffers."""
        if not self.detailed:
            return
        self.train_step_begin_time = time.perf_counter()
        self.accum_begin_times = []
        self.backward_begin_times = []

    def on_training_accum_step_begin(
        self, model: FastGenModel, data_batch: dict[str, torch.Tensor], iteration: int = 0, accum_iter: int = 0
    ):
        """Record the start of one gradient-accumulation micro-step."""
        if self.detailed:
            self.accum_begin_times.append(time.perf_counter())

    def on_backward_begin(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
        accum_iter: int = 0,
    ):
        """Record when backward starts (i.e. when the forward pass ended)."""
        if self.detailed:
            self.backward_begin_times.append(time.perf_counter())

    def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0):
        """Record when the optimizer step starts (i.e. when backward ended)."""
        if self.detailed:
            self.optimizer_step_begin = time.perf_counter()

    def on_training_step_end(
        self,
        model: FastGenModel,
        data_batch: dict[str, torch.Tensor],
        output_batch: dict[str, torch.Tensor | Callable],
        loss_dict: dict[str, torch.Tensor],
        iteration: int = 0,
    ) -> None:
        """Close out the step; every ``every_n`` iterations log timing metrics."""
        del data_batch, output_batch, loss_dict

        if self.detailed:
            self.step_end_time = time.perf_counter()

        # Metrics are only emitted when a config has been attached.
        if not hasattr(self, "config"):
            return
        if iteration % self.every_n != 0 or not is_rank0():
            return

        metrics = {}

        # Averaged wall-clock time per iteration since the previous log.
        cur_time = time.time()
        if self.last_log_time is not None:
            iter_time = (cur_time - self.last_log_time) / self.every_n
            logger.info(f"{iteration} : avg iteration time {iter_time:.2f} seconds")
            metrics["profiler/avg_iteration_time"] = iter_time
        self.last_log_time = cur_time

        # Detailed per-phase breakdown of the most recent step.
        if self.detailed and self.accum_begin_times and self.backward_begin_times:
            data_load_time = self.accum_begin_times[0] - self.train_step_begin_time
            forward_durations = [
                end - begin for (end, begin) in zip(self.backward_begin_times, self.accum_begin_times)
            ]
            forward_time = sum(forward_durations) / len(self.accum_begin_times)
            backward_time = self.optimizer_step_begin - self.backward_begin_times[-1]
            optim_step_time = self.step_end_time - self.optimizer_step_begin

            logger.info(f"{iteration} : data loading time {data_load_time:.2f}")
            logger.info(f"{iteration} : avg forward pass time {forward_time:.2f}")
            logger.info(f"{iteration} : backward pass time {backward_time:.2f}")
            logger.info(f"{iteration} : optimizer step time {optim_step_time:.2f}")

            metrics["profiler/data_loading_time"] = data_load_time
            metrics["profiler/avg_forward_pass_time"] = forward_time
            metrics["profiler/backward_pass_time"] = backward_time
            metrics["profiler/optimizer_step_time"] = optim_step_time

        if wandb.run and metrics:
            wandb.log(metrics, step=iteration)
FastGen/fastgen/callbacks/wandb.py ADDED
@@ -0,0 +1,404 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from __future__ import annotations
5
+ import os
6
+ from dataclasses import dataclass, field
7
+ import time
8
+ from typing import Optional, Dict, Callable, TYPE_CHECKING
9
+ import gc
10
+
11
+
12
+ import torch
13
+ import torchvision
14
+ from torchvision.transforms import functional as tv_F
15
+
16
+ import wandb
17
+ import wandb.util
18
+
19
+ from fastgen.callbacks.callback import Callback
20
+ from fastgen.configs.config_utils import serialize_config
21
+ from fastgen.utils import basic_utils
22
+
23
+ from fastgen.utils.distributed import rank0_only, synchronize, world_size
24
+ from fastgen.utils import logging_utils as logger
25
+
26
+ if TYPE_CHECKING:
27
+ from fastgen.configs.config import BaseConfig
28
+ from fastgen.methods import FastGenModel
29
+
30
+
31
+ def to_wandb(
32
+ tensor: torch.Tensor,
33
+ rgb_range: float = 255.0,
34
+ normalized: bool = False,
35
+ max_plot_img: int = 16,
36
+ max_plot_vid: int = 2,
37
+ fps: int = 16,
38
+ channel_before_time: bool = True,
39
+ caption: str | None = None,
40
+ vid_format: str = "mp4",
41
+ ) -> wandb.Image | wandb.Video:
42
+ """
43
+ Convert a tensor to a wandb.Image or wandb.Video.
44
+
45
+ Args:
46
+ tensor (torch.Tensor): Input tensor of shape [B,C,H,W], [B,T,C,H,W], or [B,T,C,H,W,D].
47
+ rgb_range (float, optional): Output target RGB range (can almost definitely be kept as 255).
48
+ Defaults to 255.0.
49
+ normalized (bool, optional): Whether the tensor is normalized to [0,1]. Defaults to False which assumes [-1,1] range.
50
+ max_plot_img (int, optional): Max number of images to plot. Defaults to 16.
51
+ max_plot_vid (int, optional): Max number of videos to plot. Defaults to 2.
52
+ fps (int, optional): Frames per second. Defaults to 8.
53
+ channel_before_time (bool, optional): Whether the tensor is in the format [B,C,T,..]. Set False if the [B,T,C,..] format is used.
54
+ caption (str, optional): Caption for the image or video. Defaults to None.
55
+ vid_format (str, optional): Format of the video file. Defaults to "mp4".
56
+
57
+ Returns:
58
+ wandb.Image | wandb.Video: Format a tensor for logging to W&B.
59
+ """
60
+
61
+ if tensor.ndim == 5:
62
+ max_plot = max_plot_vid
63
+ if channel_before_time:
64
+ tensor = tensor.permute(0, 2, 1, 3, 4)
65
+ elif tensor.ndim == 4:
66
+ max_plot = max_plot_img
67
+ else:
68
+ raise ValueError(f"Tensor must be 4 or 5 dimensional, but got {tensor.ndim} dimensions")
69
+
70
+ # slice and adjust range
71
+ if normalized:
72
+ factor = rgb_range
73
+ offset = 0.0
74
+ else:
75
+ factor = rgb_range / 2.0
76
+ offset = rgb_range / 2.0
77
+ tensor = tensor[:max_plot].mul(factor).add(offset).clip_(0, rgb_range).to(torch.uint8)
78
+
79
+ # convert to wandb.Image or wandb.Video
80
+ assert tensor.shape[-3] == 3, "Make sure that the data is in ..., C, H, W format"
81
+ if tensor.ndim == 5:
82
+ return wandb.Video(tensor.cpu().numpy(), fps=fps, format=vid_format, caption=caption)
83
+ else:
84
+ image_grid = torchvision.utils.make_grid(tensor, nrow=4, pad_value=1)
85
+ image_grid = tv_F.to_pil_image(image_grid)
86
+ return wandb.Image(image_grid, caption=caption)
87
+
88
+
89
+ @rank0_only
90
+ def init_wandb(config: BaseConfig):
91
+ # wandb login
92
+ wandb_credential = config.log_config.wandb_credential
93
+ if os.path.isfile(wandb_credential):
94
+ os.environ["WANDB_API_KEY"] = open(wandb_credential, encoding="utf-8").read().strip("\n")
95
+ logger.info(f"Loading WANDB_API_KEY from {wandb_credential}")
96
+
97
+ wandb_config = config.log_config
98
+
99
+ # Resume with or generate a wandb id
100
+ logger.info(f"wandb_config.save_path: {wandb_config.save_path}")
101
+ os.makedirs(wandb_config.save_path, exist_ok=True)
102
+ wandb_id_path = f"{wandb_config.save_path}/wandb_id.txt"
103
+ if os.path.isfile(wandb_id_path):
104
+ wandb_id = open(wandb_id_path, encoding="utf-8").read().strip()
105
+ logger.info(f"Resuming with an existing wandb id: {wandb_id}")
106
+ else:
107
+ wandb_id = wandb.util.generate_id()
108
+ with open(wandb_id_path, "w", encoding="utf-8") as f:
109
+ f.write(f"{wandb_id}\n")
110
+ logger.info(f"Generating a wandb id: {wandb_id}")
111
+
112
+ # Get config as plain dict
113
+ config_resolved = serialize_config(config, return_type="dict")
114
+
115
+ # Initialize the wandb library.
116
+ wandb.init(
117
+ id=wandb_id,
118
+ project=wandb_config.project,
119
+ group=wandb_config.group,
120
+ name=wandb_config.name,
121
+ config=config_resolved,
122
+ dir=wandb_config.save_path,
123
+ resume="allow",
124
+ mode=wandb_config.wandb_mode,
125
+ )
126
+
127
+ # Save a copy of code to a wandb Artifact (this can be slow)
128
+ # Make code upload optional to avoid distributed training delays
129
+ upload_code = basic_utils.str2bool(os.getenv("WANDB_UPLOAD_CODE", "false"))
130
+ if upload_code:
131
+ logger.info("Uploading code to wandb (this may take a few minutes)...")
132
+ wandb.run.log_code(".")
133
+ logger.info("Code upload to wandb completed")
134
+ else:
135
+ logger.info("Wandb code upload disabled (set WANDB_UPLOAD_CODE=true to enable)")
136
+
137
+
138
+ @dataclass
139
+ class _LossDictRecord:
140
+ loss_dict: dict = field(default_factory=dict)
141
+ iter_count_dict: dict = field(default_factory=dict)
142
+
143
+ def add(self, loss_dict: Optional[Dict[str, torch.Tensor]]) -> None:
144
+ if loss_dict is not None:
145
+ for loss_name, loss_val in loss_dict.items():
146
+ self.loss_dict[loss_name] = self.loss_dict.get(loss_name, 0.0) + loss_val.float().item()
147
+ self.iter_count_dict[loss_name] = self.iter_count_dict.get(loss_name, 0) + 1
148
+
149
+ def reset(self) -> None:
150
+ self.loss_dict = {}
151
+ self.iter_count_dict = {}
152
+
153
+ def gather_dict(self, dictionary: Dict[str, float | int]) -> Dict[str, float | int]:
154
+ n_ranks = world_size()
155
+ if n_ranks > 1:
156
+ dict_list = [None for _ in range(n_ranks)]
157
+ torch.distributed.all_gather_object(dict_list, dictionary)
158
+ # from list of dicts to dict of summed values
159
+ dictionary = {}
160
+ for d in dict_list:
161
+ for key, value in d.items():
162
+ dictionary[key] = dictionary.get(key, 0.0) + value
163
+ return dictionary
164
+
165
+ def get_stat(self) -> Dict[str, float]:
166
+ # number of ranks that logged this loss
167
+ rank_dict = self.gather_dict({k: 1 for k in self.loss_dict.keys()})
168
+ # number of times this loss was computed
169
+ count_dict = self.gather_dict(self.iter_count_dict)
170
+ # sum of all losses
171
+ loss_dict = self.gather_dict(self.loss_dict)
172
+
173
+ avg_loss_dict = {}
174
+ for loss_name, loss_val in loss_dict.items():
175
+ count = count_dict.get(loss_name, 0)
176
+ ranks = rank_dict.get(loss_name, 1)
177
+ iter_count = count / ranks
178
+ avg_loss = (loss_val / count) * (ranks / world_size()) if count > 0 else 0.0
179
+ logger.info(f"avg_{loss_name}: {avg_loss:.4f}".ljust(30) + f"iter count: {iter_count}")
180
+ avg_loss_dict[loss_name] = avg_loss
181
+ self.reset()
182
+ return avg_loss_dict
183
+
184
+
185
+ class WandbCallback(Callback):
186
+ """
187
+ The callback gets precision for data from model
188
+ """
189
+
190
+ def __init__(
191
+ self,
192
+ *args,
193
+ validation_logging_step: int = 1,
194
+ sample_logging_iter: Optional[int] = None,
195
+ vid_format: str = "mp4",
196
+ **kwargs,
197
+ ):
198
+ super().__init__(*args, **kwargs)
199
+
200
+ self.validation_logging_step = validation_logging_step
201
+ self.sample_logging_iter = sample_logging_iter
202
+ self.val_sample_map = None
203
+ self.vid_format = vid_format
204
+ self.loss_dict_record = _LossDictRecord()
205
+ self.val_loss_dict_record = _LossDictRecord()
206
+
207
+ def on_app_begin(self) -> None:
208
+ assert hasattr(self, "config"), "Missing config in WandbCallback."
209
+ init_wandb(self.config)
210
+ self.offload_module_in_decoding = self.config.trainer.offload_module_in_decoding
211
+ # disable offloading if using FSDP
212
+ if self.config.trainer.fsdp:
213
+ self.offload_module_in_decoding = False
214
+ if self.sample_logging_iter is None:
215
+ self.sample_logging_iter = self.config.trainer.logging_iter
216
+ synchronize()
217
+
218
+ @rank0_only
219
+ def on_optimizer_step_begin(self, model: FastGenModel, iteration: int = 0) -> None:
220
+ assert hasattr(self, "config"), "Missing config in WandbCallback."
221
+ if iteration % self.config.trainer.logging_iter == 0:
222
+ for name, scheduler in model.scheduler_dict.items():
223
+ wandb.log({f"optimizer/lr_{name}": scheduler.get_last_lr()[0]}, step=iteration)
224
+
225
+ def get_sample_map(
226
+ self, model: FastGenModel, data_batch: dict[str, torch.Tensor], output_batch: dict[str, torch.Tensor | Callable]
227
+ ) -> dict[str, wandb.Image | wandb.Video]:
228
+ # Collect generated and real data and create copies to avoid modifying the original dicts
229
+ sample_map = {}
230
+ gen_rand = output_batch["gen_rand"]
231
+ if isinstance(gen_rand, Callable):
232
+ synchronize()
233
+ gen_rand = gen_rand()
234
+ synchronize()
235
+
236
+ # Avoid modifying the original dicts
237
+ data_batch = data_batch.copy()
238
+ output_batch = output_batch.copy()
239
+
240
+ # Decide whether we want to visualize multistep teacher generation
241
+ if self.config.trainer.visualize_teacher:
242
+ assert "input_rand" in output_batch, "We need to know the noise to visualize teacher generation"
243
+ teacher_output = model.sample(
244
+ model.teacher,
245
+ output_batch["input_rand"][0:1],
246
+ data_batch["condition"][0:1], # e.g. text condition encoded by the text encoder
247
+ data_batch["neg_condition"][0:1], # e.g. negative text condition encoded by the text encoder
248
+ )
249
+ output_batch["gen_teacher"] = teacher_output
250
+
251
+ # Decode to pixel if it's in latent space
252
+ if hasattr(model.net, "init_preprocessors"):
253
+ torch.cuda.empty_cache()
254
+ device_nets = model.device
255
+
256
+ has_vae = hasattr(model.net, "vae")
257
+ if not has_vae:
258
+ model.net.init_vae()
259
+ model.net.vae.to(device=device_nets, dtype=model.precision)
260
+
261
+ if self.offload_module_in_decoding:
262
+ # offload the unneeded models to CPU (enable it if hitting OOM here)
263
+ logger.info(
264
+ f"GPU Memory BEFORE moving nets to CPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
265
+ )
266
+ if hasattr(model, "fake_score"):
267
+ model.fake_score = model.fake_score.to("cpu")
268
+ if hasattr(model, "teacher"):
269
+ model.teacher = model.teacher.to("cpu")
270
+ logger.info(
271
+ f"GPU Memory AFTER moving nets to CPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
272
+ )
273
+ synchronize()
274
+
275
+ with basic_utils.inference_mode(precision_amp=model.precision_amp_enc, device_type=device_nets.type):
276
+ if "real" in data_batch:
277
+ # only generate one sample for video
278
+ limit = 1 if len(data_batch["real"].shape) == 5 else len(data_batch["real"])
279
+ data_batch["real"] = model.net.vae.decode(data_batch["real"][:limit])
280
+ if isinstance(gen_rand, dict):
281
+ for k in gen_rand:
282
+ limit = 1 if len(gen_rand[k].shape) == 5 else len(gen_rand[k])
283
+ gen_rand[k] = model.net.vae.decode(gen_rand[k][:limit])
284
+ else:
285
+ limit = 1 if len(gen_rand.shape) == 5 else len(gen_rand)
286
+ gen_rand = model.net.vae.decode(gen_rand[:limit])
287
+
288
+ if "gen_teacher" in output_batch:
289
+ output_batch["gen_teacher"] = model.net.vae.decode(output_batch["gen_teacher"][:limit])
290
+ if logger.LOG_LEVEL == "DEBUG" and "gen_rand_train" in output_batch:
291
+ output_batch["gen_rand_train"] = model.net.vae.decode(output_batch["gen_rand_train"][:limit])
292
+
293
+ if not has_vae:
294
+ del model.net.vae
295
+
296
+ if self.offload_module_in_decoding:
297
+ # move back fake_score to gpu
298
+ if hasattr(model, "fake_score"):
299
+ model.fake_score = model.fake_score.to(device_nets)
300
+ if hasattr(model, "teacher"):
301
+ model.teacher = model.teacher.to(device_nets)
302
+ logger.info(
303
+ f"GPU Memory AFTER moving nets back to GPU: {torch.cuda.memory_allocated(device_nets) / 1024 ** 2:.2f} MB"
304
+ )
305
+ synchronize()
306
+
307
+ if wandb.run:
308
+ if (
309
+ "condition_raw" in data_batch
310
+ and isinstance(data_batch["condition_raw"], (list, tuple))
311
+ and isinstance(data_batch["condition_raw"][0], str)
312
+ ):
313
+ caption = "\n".join(data_batch["condition_raw"][: len(gen_rand)])
314
+ else:
315
+ caption = None
316
+ if isinstance(gen_rand, dict):
317
+ for k in gen_rand:
318
+ sample_map[f"student/generation/{k}"] = to_wandb(
319
+ gen_rand[k], caption=caption, vid_format=self.vid_format
320
+ )
321
+ else:
322
+ sample_map["student/generation"] = to_wandb(gen_rand, caption=caption, vid_format=self.vid_format)
323
+ if "real" in data_batch:
324
+ sample_map["data/real"] = to_wandb(data_batch["real"], caption=caption, vid_format=self.vid_format)
325
+ if "gen_teacher" in output_batch:
326
+ sample_map["teacher/generation"] = to_wandb(
327
+ output_batch["gen_teacher"], caption=caption, vid_format=self.vid_format
328
+ )
329
+ if logger.LOG_LEVEL == "DEBUG" and "gen_rand_train" in output_batch:
330
+ sample_map["student/generation_train"] = to_wandb(
331
+ output_batch["gen_rand_train"], caption=caption, vid_format=self.vid_format
332
+ )
333
+
334
+ return sample_map
335
+
336
+ def log_sample_map(
337
+ self,
338
+ model: FastGenModel,
339
+ data_batch: dict[str, torch.Tensor],
340
+ output_batch: dict[str, torch.Tensor | Callable],
341
+ suffix: str = "",
342
+ iteration: int = 0,
343
+ group: str = "train",
344
+ ) -> None:
345
+ sample_map = self.get_sample_map(model, data_batch, output_batch)
346
+ sample_map = {f"{group}_media/{k}{suffix}": v for k, v in sample_map.items()}
347
+ if wandb.run:
348
+ wandb.log(sample_map, step=iteration)
349
+ synchronize()
350
+ gc.collect()
351
+ torch.cuda.empty_cache()
352
+
353
+ def log_stats(self, loss_dict_record: _LossDictRecord, iteration: int = 0, group: str = "train") -> None:
354
+ logger.info(f"logging {group} stats at iteration {iteration}" + "-" * 20)
355
+ # Collect distributed statistics
356
+ avg_loss_dict = loss_dict_record.get_stat()
357
+ stats = {f"{group}/{name}": val for name, val in avg_loss_dict.items()}
358
+ base_info = {"optimizer/iteration": iteration}
359
+
360
+ # log stats and base info
361
+ if wandb.run:
362
+ wandb.log(stats, step=iteration)
363
+ wandb.log(base_info, step=iteration)
364
+
365
+ def on_training_step_end(
366
+ self,
367
+ model: FastGenModel,
368
+ data_batch: dict[str, torch.Tensor],
369
+ output_batch: dict[str, torch.Tensor | Callable],
370
+ loss_dict: dict[str, torch.Tensor],
371
+ iteration: int = 0,
372
+ ) -> None:
373
+ self.loss_dict_record.add(loss_dict)
374
+ time_start = time.perf_counter()
375
+ logged = False
376
+ if iteration % self.config.trainer.logging_iter == 0 or iteration == 1:
377
+ self.log_stats(self.loss_dict_record, iteration=iteration, group="train")
378
+ logged = True
379
+ if iteration % self.sample_logging_iter == 0 or iteration == 1:
380
+ self.log_sample_map(model, data_batch, output_batch, iteration=iteration, group="train")
381
+ logged = True
382
+ if logged:
383
+ time_taken = time.perf_counter() - time_start
384
+ logger.info(f"WandB logging complete after {time_taken:.2f} seconds")
385
+
386
+ def on_validation_step_end(
387
+ self,
388
+ model: FastGenModel,
389
+ data_batch: dict[str, torch.Tensor],
390
+ output_batch: dict[str, torch.Tensor | Callable],
391
+ loss_dict: dict[str, torch.Tensor],
392
+ step: int = 0,
393
+ iteration: int = 0,
394
+ idx: int = 0,
395
+ ) -> None:
396
+ self.val_loss_dict_record.add(loss_dict)
397
+
398
+ if step % self.validation_logging_step == 0:
399
+ self.log_sample_map(
400
+ model, data_batch, output_batch, suffix=f"_{step}", iteration=iteration, group=f"val{idx}"
401
+ )
402
+
403
+ def on_validation_end(self, model: FastGenModel, iteration: int = 0, idx: int = 0) -> None:
404
+ self.log_stats(self.val_loss_dict_record, iteration=iteration, group=f"val{idx}")
FastGen/fastgen/configs/README.md ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Configuration System
2
+
3
+ FastGen uses a hierarchical Python-based configuration system built on [Hydra](https://hydra.cc/), [OmegaConf](https://omegaconf.readthedocs.io/), and [attrs](https://www.attrs.org/).
4
+
5
+ ## Directory Structure
6
+
7
+ ```
8
+ fastgen/configs/
9
+ ├── experiments/ # Experiment configs (cifar10, Wan, sdxl, etc.)
10
+ ├── methods/ # Method-specific configs (DMD2, CM, KD, SFT, etc.)
11
+ ├── callbacks.py # Callback configurations (EMA, WandB, GradClip, etc.)
12
+ ├── config_utils.py # Utilities (import, override, serialize configs)
13
+ ├── config.py # Base config classes (BaseConfig, BaseModelConfig, BaseTrainerConfig)
14
+ ├── data.py # Dataset/dataloader configurations
15
+ ├── discriminator.py # Discriminator configs (for GAN-based methods)
16
+ ├── net.py # Network architecture configurations
17
+ ├── opt.py # Optimizer and scheduler configurations
18
+ ```
19
+
20
+ ## Config Hierarchy
21
+
22
+ 1. **Base Configs** (`config.py`): Core classes defining model, trainer, data, and logging settings
23
+ 2. **Method Configs** (`methods/`): Extend base configs with method-specific parameters
24
+ 3. **Experiment Configs** (`experiments/`): Concrete configs for specific dataset/model combinations
25
+
26
+ ## LazyCall Pattern
27
+
28
+ Deferred instantiation using `LazyCall`:
29
+
30
+ ```python
31
+ from fastgen.utils import LazyCall as L
32
+ from fastgen.methods import DMD2Model
33
+
34
+ model_class = L(DMD2Model)(config=None) # Config dict with _target_
35
+ model = instantiate(model_class) # Instantiate later
36
+ ```
37
+
38
+ ## Command-Line Arguments
39
+
40
+ ```bash
41
+ python train.py --config=path/to/config.py [--log_level LEVEL] [--dryrun] - key=value
42
+ ```
43
+
44
+ | Argument | Description |
45
+ |----------|-------------|
46
+ | `--config` | Path to the config file |
47
+ | `--log_level` | Log level: DEBUG, INFO (default), WARNING, ERROR |
48
+ | `--dryrun` | Print resolved config and exit without training |
49
+ | `-` | Separator before config overrides (required) |
50
+
51
+ Examples:
52
+
53
+ ```bash
54
+ # Override training settings
55
+ python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py - \
56
+ trainer.max_iter=10000 \
57
+ model.gan_loss_weight_gen=0. \
58
+ log_config.name=my_experiment
59
+
60
+ # Debug config without training
61
+ python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py --dryrun
62
+
63
+ # Verbose logging
64
+ python train.py --config=fastgen/configs/experiments/EDM/config_dmd2_test.py --log_level DEBUG
65
+ ```
66
+
67
+ ## Key Config Classes
68
+
69
+ | Class | Purpose |
70
+ |-------|---------|
71
+ | `BaseConfig` | Top-level: model, trainer, dataloader, logging |
72
+ | `BaseModelConfig` | Network, optimizer, precision, EMA, guidance |
73
+ | `BaseTrainerConfig` | Checkpointing, callbacks, DDP/FSDP, iterations |
74
+ | `LogConfig` | Project, group, name, wandb settings |
75
+
76
+ ## Environment Variables
77
+
78
+ ### Core Variables
79
+
80
+ | Variable | Description | Default |
81
+ |----------|-------------|---------|
82
+ | `FASTGEN_OUTPUT_ROOT` | Root directory for checkpoints, logs, and outputs | `FASTGEN_OUTPUT` |
83
+ | `DATA_ROOT_DIR` | Root directory for datasets | `$FASTGEN_OUTPUT_ROOT/DATA` |
84
+ | `CKPT_ROOT_DIR` | Root directory for pretrained checkpoints | `$FASTGEN_OUTPUT_ROOT/MODEL` |
85
+ | `HF_HOME` | HuggingFace cache directory | `$FASTGEN_OUTPUT_ROOT/.cache` |
86
+ | `LOCAL_FILES_ONLY` | Use only local files, skip downloads from HuggingFace | `false` |
87
+ | `WANDB_API_KEY` | W&B API key for persistent logging ([get yours here](https://wandb.ai/settings)) | (none, W&B will prompt) |
88
+ | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, `AWS_ENDPOINT_URL` | S3 credentials for data and checkpoint storage | (none) |
89
+
90
+ ### Loading Credentials from Files
91
+
92
+ As an alternative to setting environment variables directly, credentials can be loaded automatically from files in the `./credentials/` directory:
93
+
94
+ | File | Environment Variables Set |
95
+ |------|---------------------------|
96
+ | `./credentials/wandb_api.txt` | `WANDB_API_KEY` (plain text file containing your API key) |
97
+ | `./credentials/s3.json` | `AWS_ACCESS_KEY_ID`, `AWS_SECRET_ACCESS_KEY`, `AWS_REGION`, `AWS_ENDPOINT_URL` (JSON format below) |
98
+
99
+ **Format for `s3.json`:**
100
+
101
+ ```json
102
+ {
103
+ "aws_access_key_id": "<your_access_key>",
104
+ "aws_secret_access_key": "<your_secret_key>",
105
+ "region_name": "<region>",
106
+ "endpoint_url": "<s3_endpoint_url>"
107
+ }
108
+ ```
FastGen/fastgen/configs/__init__.py ADDED
@@ -0,0 +1,2 @@
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
FastGen/fastgen/configs/callbacks.py ADDED
@@ -0,0 +1,63 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from fastgen.utils import LazyCall as L
5
+
6
+ from fastgen.callbacks.ct_schedule import CTScheduleCallback
7
+ from fastgen.callbacks.grad_clip import GradClipCallback
8
+ from fastgen.callbacks.param_count import ParamCountCallback
9
+ from fastgen.callbacks.wandb import WandbCallback
10
+ from fastgen.callbacks.ema import EMACallback
11
+ from fastgen.callbacks.train_profiler import TrainProfilerCallback
12
+ from fastgen.callbacks.gpu_stats import GPUStatsCallback
13
+ from fastgen.callbacks.forced_weight_norm import ForcedWeightNormCallback
14
+ from fastgen.callbacks.gpu_mem_profiler import MemTrackerCallback
15
+
16
+
17
+ CTSchedule_CALLBACK = dict(
18
+ ct_schedule=L(CTScheduleCallback)(q=2.0, ratio_limit=0.999, kimg_per_stage=12500),
19
+ )
20
+
21
+ EMA_CALLBACK = dict(
22
+ ema=L(EMACallback)(type="constant", beta=0.9999, gamma=16.97, ema_halflife_kimg=500, ema_rampup_ratio=0.05),
23
+ )
24
+
25
+ EMA_CONST_CALLBACKS = dict(
26
+ ema_9999=L(EMACallback)(type="constant", beta=0.9999, ema_name="ema_9999"),
27
+ ema_99995=L(EMACallback)(type="constant", beta=0.99995, ema_name="ema_99995"),
28
+ ema_9996=L(EMACallback)(type="constant", beta=0.9996, ema_name="ema_9996"),
29
+ )
30
+
31
+ EMA_POWER_CALLBACKS = dict(
32
+ ema_1=L(EMACallback)(type="power", gamma=96.99, ema_name="ema_1"),
33
+ ema_5=L(EMACallback)(type="power", gamma=16.97, ema_name="ema_5"),
34
+ ema_10=L(EMACallback)(type="power", gamma=6.94, ema_name="ema_10"),
35
+ )
36
+
37
+ ForcedWeightNorm_CALLBACK = dict(
38
+ forced_weight_norm=L(ForcedWeightNormCallback)(),
39
+ )
40
+
41
+ GradClip_CALLBACK = dict(
42
+ grad_clip=L(GradClipCallback)(grad_norm=10.0, model_key="net"),
43
+ )
44
+
45
+ GPUStats_CALLBACK = dict(
46
+ gpu_stats=L(GPUStatsCallback)(every_n=100),
47
+ )
48
+
49
+ ParamCount_CALLBACK = dict(
50
+ param_count=L(ParamCountCallback)(),
51
+ )
52
+
53
+ TrainProfiler_CALLBACK = dict(
54
+ train_profiler=L(TrainProfilerCallback)(every_n=100),
55
+ )
56
+
57
+ WANDB_CALLBACK = dict(
58
+ wandb=L(WandbCallback)(sample_logging_iter=None),
59
+ )
60
+
61
+ MemTracker_CALLBACK = dict(
62
+ mem_tracker=L(MemTrackerCallback)(),
63
+ )
FastGen/fastgen/configs/config.py ADDED
@@ -0,0 +1,254 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ from typing import Any, List, Optional, Dict
6
+
7
+ import copy
8
+ import attrs
9
+ from omegaconf import DictConfig
10
+
11
+ from fastgen.utils import LazyCall as L
12
+ from fastgen.configs.callbacks import WANDB_CALLBACK
13
+ from fastgen.configs.data import CIFAR10_Loader_Config
14
+ from fastgen.configs.net import EDM_CIFAR10_Config as EDMConfig
15
+ from fastgen.configs.opt import BaseOptimizerConfig, BaseSchedulerConfig
16
+ from fastgen.methods import FastGenModel
17
+
18
+
19
@attrs.define(slots=False)
class CuDNNConfig:
    """cuDNN behavior flags applied by the trainer."""

    # If set to True, cudnn will use deterministic cudnn functions for better reproducibility.
    deterministic: bool = False
    # If set to True, cudnn will benchmark several algorithms and pick the fastest one.
    benchmark: bool = True
25
+
26
+
27
@attrs.define(slots=False)
class LogConfig:
    """Experiment identification and W&B logging settings."""

    # Project name
    project: str = "fastgen"
    # Experiment name
    group: str = "cifar10"
    # Run/job name
    name: str = "debug"
    # W&B mode, can be "online" or "disabled".
    wandb_mode: str = "online"
    # Wandb credential path
    wandb_credential: str = "./credentials/wandb_api.txt"

    # save path
    @property
    def save_path(self) -> str:
        """Output directory: <FASTGEN_OUTPUT_ROOT>/<project>/<group>/<name>."""
        return os.path.join(
            os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT"), f"{self.project}/{self.group}/{self.name}"
        )
46
+
47
+
48
@attrs.define(slots=False)
class EvalConfig:
    """Evaluation settings: sample counts, checkpoint range, and output location."""

    # Number of samples to generate
    num_samples: int = 50000
    # Save a small batch of images
    save_images: bool = False
    # Minimum checkpoint to evaluate
    min_ckpt: int = 0
    # Maximum checkpoint to evaluate
    max_ckpt: int = 100000000
    # Directory to save samples
    samples_dir: str = "samples"
60
+
61
+
62
@attrs.define(slots=False)
class BaseCheckpointerConfig:
    """Checkpoint saving/loading locations, locally and optionally on S3."""

    save_dir: str = "checkpoints"
    use_s3: bool = False
    s3_container: str = "s3://checkpoints/fastgen"
    s3_credential: str = "./credentials/s3.json"

    # path to pretrained model (from previous stages),
    # it's used by loading fsdp/ddp trained ckpt to an fsdp/ddp pipeline
    pretrained_ckpt_path: str = ""
    # submodule names of model and keys of a pretrained checkpoint of the form {"model": {"submodule_key": ...}, ...}
    # NOTE: attrs.define converts this mutable dict default into a per-instance factory.
    pretrained_ckpt_key_map: Dict[str, str] = {"net": "net"}
74
+
75
+
76
@attrs.define(slots=False)
class SampleTConfig:
    """Config for sampling t from a time distribution."""

    # time distribution (currently supporting: uniform, lognormal, polynomial, logitnormal, shift, and log_t)
    time_dist_type: str = "uniform"
    # mu in lognormal, logitnormal, and log_t distributions
    train_p_mean: float = -1.1
    # sigma in lognormal, logitnormal, and log_t distributions
    train_p_std: float = 2.0
    # shift value in shifted sampling (t_shifted = t * shift / (t * (shift - 1) + 1))
    shift: float = 5.0
    # lowest value in truncated range
    min_t: float = 0.002
    # highest value in truncated range
    max_t: float = 80.0
    # Explicit discrete schedule for multistep generation.
    # If provided, it is in the form [t_max, ..., 0] where len(t_list) needs to equal student_sample_steps + 1
    t_list: Optional[List[float]] = None
    # degree of freedom in log-transformed student-t distribution
    log_t_df: float = 0.01
96
+
97
+
98
@attrs.define(slots=False)
class BaseModelConfig:
    """Model-side settings shared by FastGen methods: networks, optimizers,
    time-step sampling, checkpoint paths, parallelism, and precision."""

    # Use factory functions to ensure each instance gets its own copy
    net: dict = attrs.field(factory=lambda: copy.deepcopy(EDMConfig))
    teacher: Optional[dict] = None  # Usually not used, only used when teacher is different from net (i.e. Causvid)

    # guidance scale for classifier-free guidance in teacher diffusion model. None means no guidance.
    guidance_scale: Optional[float] = None

    # enable skip layer guidance (currently only wan network has the skip_layers option in cfg)
    skip_layers: List[int] | None = None

    # optimizer and scheduler for the main net (i.e., one-step generator in DMD)
    net_optimizer: dict = attrs.field(factory=lambda: copy.deepcopy(BaseOptimizerConfig))
    net_scheduler: dict = attrs.field(factory=lambda: copy.deepcopy(BaseSchedulerConfig))

    # sampling t from a given distribution
    sample_t_cfg: SampleTConfig = attrs.field(factory=SampleTConfig)

    # shape of the input to the model (defaults to CIFAR-10)
    # NOTE: attrs.define converts this mutable list default into a per-instance factory.
    input_shape: List[int] = [3, 32, 32]
    # device ("cuda" or "cpu")
    device: str = "cuda"

    # enable gradient scaler
    grad_scaler_enabled: bool = False
    grad_scaler_init_scale: float = 65536.0
    grad_scaler_growth_interval: int = 2000

    # path to the pretrained teacher model ckpt
    pretrained_model_path: str = ""
    # path to the pretrained student net ckpt (if different from the teacher)
    pretrained_student_net_path: str = ""
    # initialize student from the above checkpoints (can be turned off to only load weights to the teacher)
    load_student_weights: bool = True

    # enable preprocessors in the model
    enable_preprocessors: bool = True

    # EMA for the main net (requires EMACallback)
    use_ema: Any = False

    # multistep generation if larger than 1 (default: single-step generation)
    student_sample_steps: int = 1
    # sampling type in multistep generation ('sde', 'ode')
    student_sample_type: str = "sde"

    # Enable memory-efficient model loading with meta device:
    # - Rank 0 loads pretrained weights normally
    # - Other ranks use torch.device("meta") for ZERO memory allocation (just metadata)
    # - FSDP materializes meta tensors and broadcasts weights from rank 0
    # This dramatically speeds up initialization for large models (14B+):
    # - Reduces RAM from N*model_size to 1*model_size
    # - Eliminates disk I/O contention (N parallel reads -> 1 read)
    # - Expected speedup: 30+ min -> <1 min for 14B models on 8 GPUs
    fsdp_meta_init: bool = False

    # whether to add the teacher model to the fsdp_dict
    add_teacher_to_fsdp_dict: bool = True

    # whether to find unused parameters in ddp
    # - can be turned off for improved performance
    # - however, it is required if the model has a discriminator or the net initializes unused modules (e.g., for logvar predictions)
    ddp_find_unused_parameters: bool = True

    # precision variables (choose from "float64", "float32", "bfloat16", or "float16")
    # (precision of the time steps is handled in the noise scheduler, defaulting to float64 for numerical stability)

    # precision for model/optimizer states and data - recommended to be float32 if precision_amp is not None
    precision: str = "float32"
    # AMP during training - if None or equal to precision, AMP is disabled during training.
    precision_amp: str | None = None
    # AMP during inference - if None or equal to precision, AMP is disabled during inference.
    precision_amp_infer: str | None = None
    # AMP during en-/decoding (e.g., for VAEs or text encoders) - if None or equal to precision, AMP is disabled during en-/decoding.
    precision_amp_enc: str | None = None
174
+
175
+
176
@attrs.define(slots=False)
class BaseTrainerConfig:
    """Training-loop settings: checkpointing, callbacks, iteration counts,
    parallelism (DDP/FSDP), and batching."""

    cudnn: CuDNNConfig = attrs.field(factory=CuDNNConfig)
    checkpointer: BaseCheckpointerConfig = attrs.field(factory=BaseCheckpointerConfig)

    # Callbacks configs.
    # Built via a factory so every config instance gets its own DictConfig. A bare
    # `DictConfig(WANDB_CALLBACK)` default is a single shared object (attrs.define
    # only auto-converts literal list/dict/set defaults to factories), so mutating
    # the callbacks of one config would leak into all others. This matches the
    # deepcopy-factory pattern already used in BaseModelConfig.
    callbacks: dict = attrs.field(factory=lambda: DictConfig(copy.deepcopy(WANDB_CALLBACK)))

    # save checkpoint frequency
    save_ckpt_iter: int = 5000
    # test on validation set frequency
    validation_iter: int = 1000
    # logging frequency
    logging_iter: int = 1000
    # maximum training iteration
    max_iter: int = 1000000
    # whether to visualize multistep teacher generation
    visualize_teacher: bool = False

    # Set the random seed.
    seed: int = 0
    # Validation seed
    val_seed: int | None = None
    # Resume
    resume: bool = True

    # DDP Parallelism
    ddp: bool = False
    # FSDP Parallelism
    fsdp: bool = False
    # Enable TensorFloat32 (convolution and matmul)
    tf32_enabled: bool = True

    # Number of gradient accumulation rounds
    grad_accum_rounds: int = 1

    # Global batch size (if not None, overrides grad_accum_rounds to match the specified batch size)
    batch_size_global: int | None = None

    # offload other modules to cpu during latent decoding
    offload_module_in_decoding: bool = False

    # apply cpu offloading in fsdp
    fsdp_cpu_offload: bool = False
    # Fallback minimum number of parameters for FSDP wrapping
    # (10M wraps large models into fairly small shards)
    # The FastGenNetwork should provide a fully_shard method that can be used to shard the network.
    # If we need to shard a different module, we fall back to an auto-sharding policy based on this value.
    fsdp_min_num_params: int = 10_000_000
    # Sharding group size for FSDP. If None, fully shard across all ranks.
    # If set, creates a 2D mesh with (replicate, shard) dimensions.
    fsdp_sharding_group_size: Optional[int] = None

    # global variables
    global_vars: Optional[dict] = None
    # NOTE: attrs.define converts this mutable list default into a per-instance factory.
    global_vars_val: List[dict | None] = [None]

    # augment config
    augment_pipe: Optional[DictConfig] = None
235
+
236
+
237
@attrs.define(slots=False)
class BaseConfig:
    """Top-level config aggregating logging, trainer, model, data, and eval sections."""

    # Log config.
    log_config: LogConfig = attrs.field(factory=LogConfig)

    # Trainer configs.
    trainer: BaseTrainerConfig = attrs.field(factory=BaseTrainerConfig)

    # Model configs.
    model: BaseModelConfig = attrs.field(factory=BaseModelConfig)
    # Factory so each instance gets its own LazyCall DictConfig: a class-level
    # `L(FastGenModel)(config=None)` default is one shared object across all
    # instances (attrs.define only auto-factories literal list/dict/set defaults),
    # so in-place edits from one experiment would leak into others.
    model_class: DictConfig = attrs.field(factory=lambda: L(FastGenModel)(config=None))

    # Data configs.
    # deepcopy-factory for the same reason: experiment configs mutate
    # dataloader_train in place (e.g. `config.dataloader_train.batch_size = 2`).
    dataloader_train: dict = attrs.field(factory=lambda: copy.deepcopy(CIFAR10_Loader_Config))
    dataloader_val: Any = None

    # Eval configs.
    eval: EvalConfig = attrs.field(factory=EvalConfig)
FastGen/fastgen/configs/config_utils.py ADDED
@@ -0,0 +1,317 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+ from typing import Any, Optional, List, Dict
6
+ import inspect
7
+ from copy import deepcopy
8
+
9
+ import attrs
10
+ import yaml
11
+ from omegaconf import DictConfig, OmegaConf, ListConfig
12
+ from hydra import compose, initialize
13
+ from hydra.core.config_store import ConfigStore
14
+
15
+ import importlib
16
+ from dataclasses import fields as dataclass_fields
17
+ import attr
18
+ from dataclasses import is_dataclass
19
+ import fastgen.utils.logging_utils as logger
20
+
21
+
22
def import_config_from_python_file(config_file: str) -> Any:
    """
    Import a config from a python file.

    The file must define a module-level ``create_config()`` function, whose
    return value is returned here.

    Args:
        config_file (str): The path to the python file.

    Returns:
        Any: The config object produced by the module's ``create_config()``.

    Raises:
        ValueError: If the path does not end in ".py".
        FileNotFoundError: If the file does not exist.
        ImportError: If the module cannot be imported.
    """

    if not config_file.endswith(".py"):
        raise ValueError("Config file must be a Python file with a .py extension. " f"Received: {config_file}")

    if not os.path.isfile(config_file):
        raise FileNotFoundError(f"FastGen config file ({config_file}) not found.")

    # Convert to importable module format. Strip only the trailing ".py"
    # extension: the previous `replace(".py", "")` removed *every* ".py"
    # occurrence in the path, corrupting module paths that contain ".py"
    # elsewhere (e.g. a directory named "my.pylib").
    config_module = config_file.removesuffix(".py").replace("/", ".")

    # Import the module
    try:
        config = importlib.import_module(config_module)
    except ImportError as e:
        logger.error(f"Failed to import config from python file: {e}")
        raise e

    return config.create_config()
50
+
51
+
52
def config_from_dict(ref_instance: Any, kwargs: Any) -> Any:
    """
    Construct an instance of the same type as ref_instance using the provided dictionary or data or unstructured data

    Args:
        ref_instance: The reference instance to determine the type and fields when needed
        kwargs: A dictionary of keyword arguments to use for constructing the new instance or primitive data or unstructured data

    Returns:
        Any: A new instance of the same type as ref_instance constructed using the provided kwargs or the primitive data or unstructured data

    Raises:
        AssertionError: If the fields do not match or if extra keys are found.
        Exception: If there is an error constructing the new instance.
    """
    is_type = is_attrs_or_dataclass(ref_instance)
    if not is_type:
        # Leaf case: primitive/unstructured data is passed through unchanged.
        return kwargs
    else:
        ref_fields = set(get_fields(ref_instance))
        assert isinstance(kwargs, dict) or isinstance(kwargs, DictConfig), "kwargs must be a dictionary or a DictConfig"
        keys = set(kwargs.keys())

        # ref_fields must equal to or include all keys
        extra_keys = keys - ref_fields
        assert (
            ref_fields == keys or keys.issubset(ref_fields)
        ), f"Fields mismatch: {ref_fields} != {keys}. Extra keys found: {extra_keys} \n \t when constructing {type(ref_instance)} with {keys}"

        # Recurse field by field; the reference instance's current attribute value
        # drives the nested reconstruction for each key present in kwargs.
        # Fields absent from kwargs fall back to the class defaults on construction.
        resolved_kwargs: Dict[str, Any] = {}
        for f in keys:
            resolved_kwargs[f] = config_from_dict(getattr(ref_instance, f), kwargs[f])
        try:
            new_instance = type(ref_instance)(**resolved_kwargs)
        except Exception as e:
            logger.error(f"Error when constructing {type(ref_instance)} with {resolved_kwargs}")
            logger.error(e)
            raise e
        return new_instance
91
+
92
+
93
def flatten_dict(nested_dict, parent_key="", sep=".", exclude_key="_target_"):
    """
    Flattens a nested dictionary by joining keys with a separator.

    Entries whose key equals *exclude_key* and entries whose value is None are
    dropped. DictConfig values are converted to plain containers before recursing.

    Args:
        nested_dict (dict): The dictionary to flatten.
        parent_key (str, optional): The base key to prepend to flattened keys.
            Used in recursion. Defaults to ''.
        sep (str, optional): The separator to use between keys. Defaults to '.'.
        exclude_key (str, optional): Key to skip entirely. Defaults to '_target_'.

    Returns:
        dict: A new, flattened dictionary.
    """
    flat = {}
    for key, value in nested_dict.items():
        # Skip the excluded key and None values outright.
        if key == exclude_key or value is None:
            continue

        compound_key = f"{parent_key}{sep}{key}" if parent_key else key

        # Normalize OmegaConf containers to plain dicts before recursing.
        if isinstance(value, DictConfig):
            value = OmegaConf.to_container(value)

        if isinstance(value, dict) and value:
            # Non-empty dict: recurse and merge the flattened children.
            flat.update(flatten_dict(value, compound_key, sep=sep, exclude_key=exclude_key))
        else:
            # Leaf node (including empty dicts).
            flat[compound_key] = value

    return flat
126
+
127
+
128
def override_config_with_opts(config: Any, opts: Optional[List[str]] = None) -> Any:
    """
    Override the config with the opts.

    The opts list must begin with a literal "-" separator, followed by Hydra
    override strings (e.g. ``["-", "trainer.max_iter=100"]``).

    Args:
        config (Any): The config object (attrs instance or DictConfig).
        opts (Optional[List[str]]): Hydra-style override strings.

    Returns:
        Any: The config object with overrides applied (same type as input).

    Raises:
        ValueError: If opts does not start with "-", or composing the overrides fails.
    """

    # Convert Config object to a DictConfig object for Hydra

    if not isinstance(config, DictConfig):
        config_dict = attrs.asdict(config)
        config_dict = DictConfig(content=config_dict, flags={"allow_objects": True})
    else:
        config_dict = config

    # Use Hydra to handle overrides
    # (register the current config as the base node for `compose`).
    cs = ConfigStore.instance()
    cs.store(name="config", node=config_dict)

    if opts is None:
        opts = []

    if len(opts) > 0 and opts[0] != "-":
        raise ValueError(f"opts must start with '-' to separate from other arguments. Got: {opts}")
    opts = opts[1:]
    with initialize(version_base=None):
        try:
            cfg = compose(config_name="config", overrides=opts)
        except Exception as e:
            raise ValueError(f"Failed to compose config with opts: {e}")

    OmegaConf.resolve(cfg)

    # Rebuild a typed config instance from the resolved DictConfig.
    config = config_from_dict(config, cfg)

    return config
169
+
170
+
171
def override_config_with_yaml(config: Any, yaml_path: str) -> Any:
    """
    Override the config with the yaml file. **_target_ field is excluded**.

    Loose overriding: yaml keys with no matching config key are silently ignored.
    """
    with open(yaml_path, "r") as f:
        loaded = yaml.safe_load(f)

    flat_yaml = flatten_dict(loaded)
    flat_config = flatten_dict(attrs.asdict(config))

    # Loose overriding: all mismatched keys are ignored
    overrides = ["-"]
    for key, value in flat_yaml.items():
        if key in flat_config:
            overrides.append(f"{key}={value}")

    return override_config_with_opts(config, opts=overrides)
185
+
186
+
187
def is_attrs_or_dataclass(obj) -> bool:
    """
    Check if the object is an instance of an attrs class or a dataclass.

    Args:
        obj: The object to check.

    Returns:
        bool: True if the object is an instance of an attrs class or a dataclass, False otherwise.
    """
    # Dataclasses are the cheap check; only consult attrs when that fails.
    if is_dataclass(obj):
        return True
    return attr.has(type(obj))
198
+
199
+
200
def get_fields(obj):
    """
    Get the fields of an attrs class or a dataclass.

    Args:
        obj: The object to get fields from. Must be an instance of an attrs class or a dataclass.

    Returns:
        list: A list of field names.

    Raises:
        ValueError: If the object is neither an attrs class nor a dataclass.
    """
    # Guard-clause style: try dataclasses first, then attrs, else fail loudly.
    if is_dataclass(obj):
        return [field.name for field in dataclass_fields(obj)]
    if attr.has(type(obj)):
        return [field.name for field in attr.fields(type(obj))]
    raise ValueError("The object is neither an attrs class nor a dataclass.")
219
+
220
+
221
def serialize_config(
    config: Any,
    return_type: str = "dict",
    path: str | bytes | None = None,
    filename: str = "config.yaml",
    include_defaults: bool = False,
) -> Dict[str, Any] | str:
    """
    Serialize a config (BaseConfig or DictConfig) to various formats.

    Args:
        config: The config to serialize (BaseConfig attrs object or DictConfig).
        return_type: Output format - "dict" (plain dict), "yaml" (YAML string), or "file" (save to file).
        path: Directory path to save the file. Required if return_type is "file".
        filename: Name of the file to save. Only used if return_type is "file".
        include_defaults: If True, add default parameter values from _target_ classes.

    Returns:
        Dict[str, Any] if return_type is "dict"
        str (YAML) if return_type is "yaml" or "file"

    Raises:
        ValueError: If return_type is "file" but path is not provided.
    """
    if return_type == "file" and path is None:
        raise ValueError("path must be provided when return_type is 'file'")

    # Deep copy to avoid modifying original
    config = deepcopy(config)

    # Normalize to DictConfig with object support
    if not isinstance(config, DictConfig):
        config_dict = attrs.asdict(config)
        config_omegaconf = DictConfig(content=config_dict, flags={"allow_objects": True})
    else:
        config_omegaconf = config

    def is_serializable(item) -> bool:
        # A value is "serializable" iff OmegaConf can render it as YAML.
        try:
            OmegaConf.to_yaml(item)
            return True
        except Exception:
            return False

    def get_default_params(cls_or_func):
        # Collect default argument values from a class __init__ or a callable.
        if callable(cls_or_func):
            signature = inspect.signature(cls_or_func)
        else:
            signature = inspect.signature(cls_or_func.__init__)
        params = signature.parameters
        return {name: param.default for name, param in params.items() if param.default is not inspect.Parameter.empty}

    def process_config(conf):
        # Recursively stringify non-serializable leaves and (optionally) inject
        # default argument values for nodes that carry a _target_.
        if isinstance(conf, DictConfig):
            for key, value in conf.items():
                if isinstance(value, (DictConfig, ListConfig)):
                    # Optionally add default params from _target_ classes
                    if include_defaults:
                        try:
                            if "_target_" in value:
                                default_params = get_default_params(value["_target_"])
                                for default_key, default_v in default_params.items():
                                    if default_key not in value:
                                        value[default_key] = default_v
                        except Exception as e:
                            logger.error(f"Failed to add default argument values: {e}")
                    process_config(value)
                else:
                    if not is_serializable(value) and value is not None:
                        conf[key] = str(value)
        elif isinstance(conf, ListConfig):
            for i, item in enumerate(conf):
                if isinstance(item, (DictConfig, ListConfig)):
                    process_config(item)
                else:
                    if not is_serializable(item) and item is not None:
                        conf[i] = str(item)
        else:
            raise NotImplementedError("Input config must be a DictConfig or ListConfig.")
        return conf

    config_omegaconf = process_config(config_omegaconf)
    result_dict: Dict[str, Any] = OmegaConf.to_container(config_omegaconf, resolve=True)  # type: ignore

    if return_type == "dict":
        return result_dict

    # For yaml and file, convert to YAML string
    yaml_str = yaml.dump(result_dict, default_flow_style=False, sort_keys=True)

    if return_type == "file":
        os.makedirs(path, exist_ok=True)  # type: ignore
        # Honor the `filename` argument (it was previously ignored and a literal
        # placeholder path was written instead).
        file_path = os.path.join(path, filename)  # type: ignore[arg-type]
        with open(file_path, "w") as f:
            f.write(yaml_str)
        logger.info(f"Config is saved at {file_path}")

    return yaml_str
FastGen/fastgen/configs/data.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import os
5
+
6
+ from fastgen.datasets.class_cond_dataloader import ImageLoader
7
+ from fastgen.datasets.wds_dataloaders import (
8
+ WDSLoader,
9
+ ImageWDSLoader,
10
+ VideoWDSLoader,
11
+ )
12
+
13
+ from fastgen.utils import LazyCall as L
14
+
15
# Root directories for outputs and datasets, overridable via environment variables.
OUTPUT_ROOT = os.environ.get("FASTGEN_OUTPUT_ROOT", "FASTGEN_OUTPUT")
DATA_ROOT_DIR = os.getenv("DATA_ROOT_DIR", f"{OUTPUT_ROOT}/DATA")
# FIX: this previously read DATA_ROOT_DIR (copy-paste), so setting a local data
# root silently redirected the S3 root as well. Prefer the dedicated
# S3_DATA_ROOT_DIR variable; fall back to DATA_ROOT_DIR for backward compatibility.
S3_DATA_ROOT_DIR = os.getenv("S3_DATA_ROOT_DIR", os.getenv("DATA_ROOT_DIR", "s3://data"))
18
+
19
+ # ################################################################################
20
+ # Generic Loaders (for config templates - override datatags for actual use)
21
+ # ################################################################################
22
+ # See fastgen/datasets/README.md for more details.
23
+
24
# Template configs: the "/path/to/..." datatags are placeholders that experiment
# configs are expected to override.

# Raw images with text conditions; negative condition is the empty-string preset.
ImageLoaderConfig = L(ImageWDSLoader)(
    datatags=["WDS:/path/to/images"],
    batch_size=32,
    key_map={"real": "jpg", "condition": "txt"},
    presets_map={"neg_condition": "empty_string"},
    input_res=512,
)

# Pre-encoded image latents with pre-computed text embeddings.
ImageLatentLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/image_latents"],
    batch_size=32,
    key_map={"real": "latent.pth", "condition": "txt_emb.pth"},
    # Negative condition embedding loaded from a shared file (same for all samples)
    files_map={"neg_condition": "/path/to/neg_prompt_emb.npy"},
)

# Raw videos with text conditions (WAN negative-prompt preset).
VideoLoaderConfig = L(VideoWDSLoader)(
    datatags=["WDS:/path/to/videos"],
    batch_size=2,
    key_map={"real": "mp4", "condition": "txt"},
    presets_map={"neg_condition": "neg_prompt_wan"},
    sequence_length=81,
    img_size=(832, 480),
)

# Pre-encoded video latents with pre-computed text embeddings.
VideoLatentLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/video_latents"],
    batch_size=2,
    key_map={"real": "latent.pth", "condition": "txt_emb.pth"},
    # Negative condition embedding loaded from a shared file (same for all samples)
    files_map={"neg_condition": "/path/to/neg_prompt_emb.npy"},
    # NOTE: For v2v tasks, add condition latent (e.g., depth) to key_map:
    # key_map={"real": "latent.pth", "condition": "txt_emb.pth", "depth_latent": "depth_latent.pth"}
)

# ################################################################################
# Generic KD Loaders (for paired/path data)
# ################################################################################
# See fastgen/methods/knowledge_distillation/README.md for more details.

# For single-step KD: provides (real, noise, condition) pairs
# Data requirements: {"real": clean, "noise": noise, "condition": cond}
PairLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/pairs"],
    batch_size=2,
    key_map={"real": "latent.pth", "noise": "noise.pth", "condition": "txt_emb.pth"},
)

# For multi-step KD: provides (real, path, condition) with denoising trajectory
# Data requirements: {"real": clean, "path": [B, steps, C, ...], "condition": cond}
# path contains intermediate denoising steps (typically 4 steps)
PathLoaderConfig = L(WDSLoader)(
    datatags=["WDS:/path/to/paths"],
    batch_size=2,
    key_map={"real": "latent.pth", "path": "path.pth", "condition": "txt_emb.pth"},
)
80
+
81
+ # ################################################################################
82
+ # Specific Datasets
83
+ # ################################################################################
84
+
85
# Class-conditional zip datasets. Each loader has a local `dataset_path` plus an
# `s3_path` fallback; `cache=True` keeps the data cached locally.

CIFAR10_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/cifar10/cifar10-32x32.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/cifar10/cifar10-32x32.zip",
    use_labels=True,
    cache=True,
    batch_size=128,
    shuffle=True,
    sampler_start_idx=None,
)

ImageNet64_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-64/imagenet-64x64.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-64/imagenet-64x64.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)

# ImageNet 256x256 pre-encoded with the SD VAE (per the filename).
ImageNet256_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-256/imagenet_256_sd.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-256/imagenet_256_sd.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)

# ImageNet-64 variant preprocessed in the EDM2 format (per the filename).
ImageNet64_EDMV2_Loader_Config = L(ImageLoader)(
    dataset_path=f"{DATA_ROOT_DIR}/imagenet-64/imagenet-64x64-edmv2.zip",
    s3_path=f"{S3_DATA_ROOT_DIR}/imagenet-64/imagenet-64x64-edmv2.zip",
    use_labels=True,
    cache=True,
    batch_size=32,
    shuffle=True,
    sampler_start_idx=None,
)
FastGen/fastgen/configs/data_dummy.py ADDED
@@ -0,0 +1,67 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Dummy data loader for data-free distillation training (e.g., Self-Forcing without GAN)."""
5
+
6
+ import torch
7
+ from fastgen.utils import LazyCall as L
8
+
9
+
10
+ class DummyVideoLoader:
11
+ """A dummy dataloader that generates random tensors for data-free training.
12
+
13
+ Use this when running Self-Forcing or other distillation methods without real data.
14
+ Requires setting gan_loss_weight_gen=0 since there's no real data for discriminator.
15
+ """
16
+
17
+ def __init__(
18
+ self,
19
+ batch_size: int = 1,
20
+ input_shape: list = [16, 21, 60, 104], # [C, T, H, W]
21
+ text_seq_len: int = 512,
22
+ text_dim: int = 4096,
23
+ device: str = "cuda",
24
+ dtype: str = "bfloat16",
25
+ ):
26
+ self.batch_size = batch_size
27
+ self.input_shape = input_shape
28
+ self.text_seq_len = text_seq_len
29
+ self.text_dim = text_dim
30
+ self.device = device
31
+ self._dtype = getattr(torch, dtype) if isinstance(dtype, str) else dtype
32
+
33
+ def __iter__(self):
34
+ return self
35
+
36
+ def __next__(self):
37
+ return {
38
+ "real": torch.randn(
39
+ self.batch_size, *self.input_shape,
40
+ device=self.device, dtype=self._dtype
41
+ ),
42
+ "condition": torch.randn(
43
+ self.batch_size, self.text_seq_len, self.text_dim,
44
+ device=self.device, dtype=self._dtype
45
+ ),
46
+ "neg_condition": torch.zeros(
47
+ self.batch_size, self.text_seq_len, self.text_dim,
48
+ device=self.device, dtype=self._dtype
49
+ ),
50
+ }
51
+
52
+
53
# Config for WAN T2V 480p (832x480, 81 frames -> 21 latent frames)
DummyVideoLoaderConfig = L(DummyVideoLoader)(
    batch_size=1,
    input_shape=[16, 21, 60, 104],  # [C, T, H, W] latent shape for 480p
    text_seq_len=512,
    text_dim=4096,
)

# Config for WAN T2V 720p (1280x720, 81 frames -> 21 latent frames)
DummyVideoLoader720pConfig = L(DummyVideoLoader)(
    batch_size=1,
    input_shape=[16, 21, 90, 160],  # [C, T, H, W] latent shape for 720p
    text_seq_len=512,
    text_dim=4096,
)
FastGen/fastgen/configs/discriminator.py ADDED
@@ -0,0 +1,106 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ from fastgen.utils import LazyCall as L
7
+ from fastgen.networks.discriminators import (
8
+ Discriminator_EDM,
9
+ Discriminator_SD15,
10
+ Discriminator_SDXL,
11
+ Discriminator_ImageDiT,
12
+ Discriminator_VideoDiT,
13
+ )
14
+
15
+ Discriminator_EDM_CIFAR10_Config: DictConfig = L(Discriminator_EDM)(
16
+ feature_indices={0, 1, 2},
17
+ all_res=[32, 16, 8],
18
+ in_channels=256,
19
+ )
20
+
21
+ Discriminator_EDM_ImageNet64_Config: DictConfig = L(Discriminator_EDM)(
22
+ feature_indices=None,
23
+ all_res=[64, 32, 16, 8],
24
+ in_channels=768,
25
+ )
26
+
27
+ Discriminator_SD15_Res512_Config: DictConfig = L(Discriminator_SD15)(
28
+ feature_indices=None,
29
+ all_res=[32, 16, 8, 8, 8],
30
+ in_channels=1280,
31
+ )
32
+
33
+ Discriminator_SDXL_Res512_Config: DictConfig = L(Discriminator_SDXL)(
34
+ feature_indices=None,
35
+ all_res=[32, 16, 16, 16],
36
+ in_channels=1280,
37
+ )
38
+
39
+ Discriminator_SDXL_Res1024_Config: DictConfig = L(Discriminator_SDXL)(
40
+ feature_indices=None,
41
+ all_res=[64, 32, 32, 32],
42
+ in_channels=1280,
43
+ )
44
+
45
+ # Flux: hidden_dim=3072, 19 joint blocks + 38 single blocks = 57 total
46
+ Discriminator_Flux_Config: DictConfig = L(Discriminator_ImageDiT)(
47
+ feature_indices=None,
48
+ num_blocks=57, # 19 joint + 38 single blocks
49
+ inner_dim=3072, # Flux hidden dimension
50
+ )
51
+
52
+ # 2B patchify: spatial-2, temporal-1; inner_dim=1920; layer=30
53
+ Discriminator_CogVideoX2B_Config = L(Discriminator_VideoDiT)(
54
+ feature_indices=None,
55
+ num_blocks=30,
56
+ disc_type="dit_simple_conv3d",
57
+ inner_dim=1920 // 4,
58
+ )
59
+
60
+ # 5B patchify: spatial-2, temporal-1; inner_dim=3072; layer=42
61
+ Discriminator_CogVideoX5B_Config = L(Discriminator_VideoDiT)(
62
+ feature_indices=None,
63
+ num_blocks=42,
64
+ disc_type="dit_simple_conv3d",
65
+ inner_dim=3072 // 4,
66
+ )
67
+
68
+ # 1.3B patchify: spatial-2, temporal-1; inner_dim=1536; layer=30
69
+ Discriminator_Wan_1_3B_Config: DictConfig = L(Discriminator_VideoDiT)(
70
+ feature_indices=None,
71
+ num_blocks=30,
72
+ disc_type="dit_simple_conv3d",
73
+ inner_dim=1536 // 4,
74
+ )
75
+
76
+ # 14B patchify: spatial-2, temporal-1; inner_dim=5120; layer=40
77
+ Discriminator_Wan_14B_Config: DictConfig = L(Discriminator_VideoDiT)(
78
+ feature_indices=None,
79
+ num_blocks=40,
80
+ disc_type="dit_simple_conv3d",
81
+ inner_dim=5120 // 4,
82
+ )
83
+
84
+ # 5B patchify: spatial-2, temporal-1; inner_dim=3072; layer=30
85
+ Discriminator_Wan22_5B_Config: DictConfig = L(Discriminator_VideoDiT)(
86
+ feature_indices=None,
87
+ num_blocks=30,
88
+ disc_type="dit_simple_conv3d",
89
+ inner_dim=3072 // 4,
90
+ )
91
+
92
+ # Cosmos Predict2.5-2B: patchify spatial-2, temporal-1; inner_dim=2048; layer=28
93
+ Discriminator_CosmosPredict2_2B_Config: DictConfig = L(Discriminator_VideoDiT)(
94
+ feature_indices=None,
95
+ num_blocks=28,
96
+ disc_type="dit_simple_conv3d",
97
+ inner_dim=2048, # Must match model's inner_dim for Cosmos
98
+ )
99
+
100
+ # Cosmos Predict2.5-14B: patchify spatial-2, temporal-1; inner_dim=5120; layer=36
101
+ Discriminator_CosmosPredict2_14B_Config: DictConfig = L(Discriminator_VideoDiT)(
102
+ feature_indices=None,
103
+ num_blocks=36,
104
+ disc_type="dit_simple_conv3d",
105
+ inner_dim=5120, # Must match model's inner_dim for Cosmos
106
+ )
FastGen/fastgen/configs/experiments/CogVideoX/config_dmd2.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ from fastgen.configs.discriminator import Discriminator_CogVideoX2B_Config
6
+ import fastgen.configs.methods.config_dmd2 as config_dmd2_default
7
+ from fastgen.configs.data import VideoLatentLoaderConfig
8
+ from fastgen.configs.net import CogVideoXConfig
9
+
10
+ """ Configs for the DMD2 model on CogVideoX model. """
11
+
12
+
13
def create_config():
    """Build the DMD2 distillation config for CogVideoX-2B."""
    config = config_dmd2_default.create_config()
    model = config.model

    # One shared learning rate for generator, discriminator and fake-score nets.
    model.net_optimizer.lr = 1e-5
    model.discriminator_optimizer.lr = 1e-5
    model.fake_score_optimizer.lr = 1e-5

    # CogVideoX latent layout [C, T, H, W].
    model.input_shape = [16, 13, 60, 90]
    model.discriminator = Discriminator_CogVideoX2B_Config
    model.discriminator.feature_indices = [15, 22, 29]
    model.gan_loss_weight_gen = 0.03
    model.net = CogVideoXConfig
    model.guidance_scale = 6.0
    model.enable_preprocessors = False

    # Uniform timestep sampling over (almost) the full unit interval.
    model.sample_t_cfg.time_dist_type = "uniform"
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    model.gan_use_same_t_noise = True
    model.fake_score_pred_type = "x0"
    model.student_sample_type = "ode"

    # 4-step student schedule.
    model.student_sample_steps = 4
    model.sample_t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0]

    config.dataloader_train = VideoLatentLoaderConfig
    config.dataloader_train.batch_size = 2

    config.trainer.max_iter = 10000
    config.trainer.logging_iter = 100
    config.trainer.save_ckpt_iter = 500

    config.log_config.group = "CogVideoX_dmd2"
    return config
FastGen/fastgen/configs/experiments/CogVideoX/config_kd.py ADDED
@@ -0,0 +1,30 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+
5
+ import fastgen.configs.methods.config_kd as config_kd_default
6
+ from fastgen.configs.data import PairLoaderConfig
7
+ from fastgen.configs.net import CogVideoXConfig
8
+
9
+ """ Configs for the KD model on CogVideoX model. """
10
+
11
+
12
def create_config():
    """Build the knowledge-distillation (KD) config for CogVideoX-2B."""
    config = config_kd_default.create_config()
    trainer, model = config.trainer, config.model

    trainer.max_iter = 6000
    trainer.logging_iter = 100
    trainer.save_ckpt_iter = 500

    model.net_optimizer.lr = 1e-4

    # CogVideoX latent layout [C, T, H, W]; preprocessing is skipped because
    # training runs on precomputed latents.
    model.input_shape = [16, 13, 60, 90]
    model.net = CogVideoXConfig
    model.enable_preprocessors = False
    model.precision = "bfloat16"

    config.dataloader_train = PairLoaderConfig
    config.dataloader_train.batch_size = 2

    config.log_config.group = "CogVideoX_kd"
    return config
FastGen/fastgen/configs/experiments/CogVideoX/config_sft.py ADDED
@@ -0,0 +1,35 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import fastgen.configs.methods.config_sft as config_sft_default
5
+ from fastgen.configs.data import VideoLatentLoaderConfig
6
+ from fastgen.configs.net import CogVideoXConfig
7
+
8
+ """ Configs for the SFT model on CogVideoX model. """
9
+
10
+
11
def create_config():
    """Build the SFT config for CogVideoX-2B."""
    config = config_sft_default.create_config()
    model = config.model

    config.trainer.logging_iter = 500

    model.net_optimizer.lr = 5e-5

    model.guidance_scale = 6.0
    model.student_sample_steps = 50

    # Uniform timestep sampling.
    model.sample_t_cfg.time_dist_type = "uniform"
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    # CogVideoX latent shape [C, T, H, W] = [16, 13, 60, 90]:
    # 49 frames at 480x720 resolution after VAE encoding.
    model.input_shape = [16, 13, 60, 90]
    model.net = CogVideoXConfig
    model.enable_preprocessors = False  # using precomputed latents

    config.dataloader_train = VideoLatentLoaderConfig
    config.dataloader_train.batch_size = 2

    config.log_config.group = "CogVideoX_sft"
    return config
FastGen/fastgen/configs/experiments/CogVideoX/config_sft_5b.py ADDED
@@ -0,0 +1,40 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ import fastgen.configs.methods.config_sft as config_sft_default
5
+ from fastgen.configs.data import VideoLatentLoaderConfig
6
+ from fastgen.configs.net import CogVideoX5BConfig
7
+
8
+ """ Configs for the SFT model on CogVideoX-5B model. """
9
+
10
+
11
+ def create_config():
12
+ config = config_sft_default.create_config()
13
+
14
+ config.trainer.logging_iter = 500
15
+
16
+ # Slightly lower LR for larger 5B model
17
+ config.model.net_optimizer.lr = 2e-5
18
+
19
+ config.model.guidance_scale = 6.0
20
+ config.model.student_sample_steps = 50
21
+
22
+ config.model.sample_t_cfg.time_dist_type = "uniform"
23
+ config.model.sample_t_cfg.min_t = 0.001
24
+ config.model.sample_t_cfg.max_t = 0.999
25
+
26
+ config.model.precision = "bfloat16"
27
+
28
+ # CogVideoX latent shape: [C, T, H, W] = [16, 13, 60, 90]
29
+ # Corresponds to 49 frames at 480x720 resolution after VAE encoding
30
+ # Same as 2B variant - they share the same VAE
31
+ config.model.input_shape = [16, 13, 60, 90]
32
+ config.model.net = CogVideoX5BConfig
33
+ config.model.enable_preprocessors = False # Using precomputed latents
34
+
35
+ config.dataloader_train = VideoLatentLoaderConfig
36
+ # Reduced batch size due to larger model
37
+ config.dataloader_train.batch_size = 1
38
+
39
+ config.log_config.group = "CogVideoX5B_sft"
40
+ return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2.py ADDED
@@ -0,0 +1,69 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """DMD2 config for Cosmos-Predict2.5-2B model."""
5
+
6
+ import fastgen.configs.methods.config_dmd2 as config_dmd2_default
7
+ from fastgen.configs.data import VideoLoaderConfig
8
+ from fastgen.configs.discriminator import Discriminator_CosmosPredict2_2B_Config
9
+ from fastgen.configs.net import CosmosPredict2_2B_Config, CosmosPredict2_2B_Aggressive_Config, CKPT_ROOT_DIR
10
+
11
+
12
def create_config():
    """Build the DMD2 distillation config for Cosmos-Predict2.5-2B."""
    config = config_dmd2_default.create_config()
    model, trainer = config.model, config.trainer

    trainer.max_iter = 10000
    trainer.logging_iter = 100
    trainer.save_ckpt_iter = 500

    # One shared learning rate for generator, discriminator and fake-score nets.
    model.net_optimizer.lr = 1e-5
    model.discriminator_optimizer.lr = 1e-5
    model.fake_score_optimizer.lr = 1e-5

    model.precision = "bfloat16"

    # Latent shape [C, T_latent, H_latent, W_latent]. The Cosmos VAE compresses
    # 4x temporally and 8x spatially.
    # 256p (320x176 video): [16, 24, 22, 40]
    # full 720p (1280x704 @ 93 frames): [16, 24, 88, 160]
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p, 93 frames

    # Network and discriminator.
    model.net = CosmosPredict2_2B_Config
    model.discriminator = Discriminator_CosmosPredict2_2B_Config
    model.discriminator.disc_type = "multiscale_down_mlp_large"
    model.discriminator.feature_indices = [13, 20, 27]
    # Teacher uses AGGRESSIVE SAC for memory savings.
    model.teacher = CosmosPredict2_2B_Aggressive_Config

    # DMD2 settings.
    model.gan_loss_weight_gen = 0.03
    model.gan_use_same_t_noise = True
    model.fake_score_pred_type = "x0"
    model.student_sample_type = "ode"
    model.guidance_scale = 3.0
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-2B/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt"

    # Timestep sampling.
    model.sample_t_cfg.time_dist_type = "shifted"
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    # 4-step student schedule.
    model.student_sample_steps = 4
    model.sample_t_cfg.t_list = [0.999, 0.937, 0.833, 0.624, 0.0]

    # Dataloader: video size (W, H) = (latent_W * 8, latent_H * 8) and frame
    # count are derived from the latent shape above.
    config.dataloader_train = VideoLoaderConfig
    config.dataloader_train.batch_size = 1
    config.dataloader_train.img_size = (model.input_shape[-1] * 8, model.input_shape[-2] * 8)
    config.dataloader_train.sequence_length = (model.input_shape[1] - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_dmd2"

    return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_14b.py ADDED
@@ -0,0 +1,37 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """DMD2 config for Cosmos-Predict2.5-14B model."""
5
+
6
+ import fastgen.configs.experiments.CosmosPredict2.config_dmd2 as config_dmd2_base
7
+ from fastgen.configs.discriminator import Discriminator_CosmosPredict2_14B_Config
8
+ from fastgen.configs.net import CosmosPredict2_14B_Config, CosmosPredict2_14B_Aggressive_Config, CKPT_ROOT_DIR
9
+
10
+
11
def create_config():
    """Build the DMD2 distillation config for Cosmos-Predict2.5-14B."""
    config = config_dmd2_base.create_config()
    model = config.model

    config.trainer.fsdp_cpu_offload = True

    # Latent shape [C, T_latent, H_latent, W_latent]. The Cosmos VAE compresses
    # 4x temporally and 8x spatially.
    # 256p (320x176 video): [16, 24, 22, 40]
    # full 720p (1280x704 @ 93 frames): [16, 24, 88, 160]
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p, 93 frames

    # 14B network and discriminator; teacher uses AGGRESSIVE SAC to save memory.
    model.net = CosmosPredict2_14B_Config
    model.discriminator = Discriminator_CosmosPredict2_14B_Config
    model.teacher = CosmosPredict2_14B_Aggressive_Config

    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-14B/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt"

    # Re-derive dataloader video size (W, H) and frame count from the latent shape.
    config.dataloader_train.img_size = (model.input_shape[-1] * 8, model.input_shape[-2] * 8)
    config.dataloader_train.sequence_length = (model.input_shape[1] - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_14b_dmd2"

    return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_dmd2_v2w.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """DMD2 config for Cosmos-Predict2.5-2B video2world model."""
5
+
6
+ import fastgen.configs.experiments.CosmosPredict2.config_dmd2 as config_dmd2_base
7
+
8
+
9
def create_config():
    """Build the DMD2 config for Cosmos-Predict2.5-2B in video2world mode."""
    config = config_dmd2_base.create_config()
    net = config.model.net

    # Switch the network into video2world mode, conditioned on a single frame.
    net.is_video2world = True
    net.num_conditioning_frames = 1

    config.log_config.group = "cosmos_predict2_dmd2_v2w"
    return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft.py ADDED
@@ -0,0 +1,48 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Configs for SFT on Cosmos-Predict2.5-2B model."""
5
+
6
+ import fastgen.configs.methods.config_sft as config_sft_default
7
+ from fastgen.configs.data import VideoLoaderConfig
8
+ from fastgen.configs.net import CosmosPredict2_2B_Config, CKPT_ROOT_DIR
9
+
10
+
11
def create_config():
    """Build the SFT config for Cosmos-Predict2.5-2B."""
    config = config_sft_default.create_config()
    model, trainer = config.model, config.trainer

    trainer.max_iter = 10000
    trainer.logging_iter = 100
    trainer.save_ckpt_iter = 500

    model.net_optimizer.lr = 1e-5

    # Uniform timestep sampling.
    model.sample_t_cfg.time_dist_type = "uniform"
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    model.precision = "bfloat16"

    # Latent shape [C, T_latent, H_latent, W_latent]. The Cosmos VAE compresses
    # 4x temporally and 8x spatially.
    # 256p: [16, 24, 24, 40]
    # 480p @ 81 frames: [16, 21, 60, 104]
    # full 720p: [16, 24, 88, 160]
    model.input_shape = [16, 24, 60, 104]  # cthw - 480p

    model.net = CosmosPredict2_2B_Config
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-2B/base/post-trained/81edfebe-bd6a-4039-8c1d-737df1a790bf_ema_bf16.pt"
    model.guidance_scale = 3.0
    model.student_sample_steps = 35

    # Dataloader: video size (W, H) = (latent_W * 8, latent_H * 8) and frame
    # count are derived from the latent shape above.
    config.dataloader_train = VideoLoaderConfig
    config.dataloader_train.batch_size = 1
    config.dataloader_train.img_size = (model.input_shape[-1] * 8, model.input_shape[-2] * 8)
    config.dataloader_train.sequence_length = (model.input_shape[1] - 1) * 4 + 1

    config.log_config.group = "cosmos_predict2_sft"

    return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_14b.py ADDED
@@ -0,0 +1,19 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Configs for SFT on Cosmos-Predict2.5-14B model."""
5
+
6
+ import fastgen.configs.experiments.CosmosPredict2.config_sft as config_sft_base
7
+ from fastgen.configs.net import CosmosPredict2_14B_Config, CKPT_ROOT_DIR
8
+
9
+
10
+ def create_config():
11
+ config = config_sft_base.create_config()
12
+
13
+ # Network for 14B
14
+ config.model.net = CosmosPredict2_14B_Config
15
+ config.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cosmos_predict2/Cosmos-Predict2.5-14B/base/post-trained/e21d2a49-4747-44c8-ba44-9f6f9243715f_ema_bf16.pt"
16
+
17
+ config.log_config.group = "cosmos_predict2_14b_sft"
18
+
19
+ return config
FastGen/fastgen/configs/experiments/CosmosPredict2/config_sft_v2w.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ """Configs for SFT on Cosmos-Predict2.5-2B video2world model."""
5
+
6
+ import fastgen.configs.experiments.CosmosPredict2.config_sft as config_sft_base
7
+
8
+
9
def create_config():
    """Build the SFT config for Cosmos-Predict2.5-2B in video2world mode."""
    config = config_sft_base.create_config()
    net = config.model.net

    # Switch the network into video2world mode, conditioned on a single frame.
    net.is_video2world = True
    net.num_conditioning_frames = 1

    config.log_config.group = "cosmos_predict2_sft_v2w"
    return config
FastGen/fastgen/configs/experiments/DiT/config_mf_b.py ADDED
@@ -0,0 +1,70 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ import fastgen.configs.methods.config_mean_flow as config_mf_default
7
+ from fastgen.configs.data import ImageNet256_Loader_Config
8
+ from fastgen.configs.net import DiT_IN256_B_Config
9
+ from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
10
+
11
+ """ Configs for the MeanFlow model, on DiT-XL and ImageNet-256 dataset. """
12
+
13
+
14
+ def create_config():
15
+ config = config_mf_default.create_config()
16
+
17
+ # model
18
+ config.model.input_shape = [4, 32, 32]
19
+ config.model.precision_amp = "bfloat16"
20
+ config.model.precision_amp_jvp = "float32"
21
+ config.model.cond_dropout_prob = 0.1
22
+ config.model.guidance_mixture_ratio = 0.5
23
+
24
+ config.model.sample_t_cfg.time_dist_type = "logitnormal"
25
+ config.model.sample_t_cfg.train_p_mean = -0.4
26
+ config.model.sample_t_cfg.train_p_std = 1.0
27
+ config.model.sample_t_cfg.min_t = 0.0
28
+ config.model.sample_t_cfg.max_t = 0.999
29
+ config.model.sample_t_cfg.r_sample_ratio = 0.25
30
+
31
+ config.model.loss_config.norm_method = "poly_1.0"
32
+ config.model.loss_config.norm_const = 1.0
33
+ config.model.loss_config.tangent_warmup_steps = 0
34
+ config.model.loss_config.loss_type = "l2"
35
+
36
+ config.model.net = DiT_IN256_B_Config
37
+ # remove the additional 1000 factor in JVP
38
+ config.model.net.scale_t = False
39
+ config.model.net.r_timestep = True
40
+ config.model.net.time_cond_type = "diff"
41
+
42
+ config.model.net_optimizer.optim_type = "adam"
43
+ config.model.net_optimizer.lr = 1e-4
44
+ config.model.net_optimizer.betas = (0.9, 0.95)
45
+ config.model.net_optimizer.weight_decay = 0.0
46
+
47
+ # ema
48
+ config.model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
49
+ config.trainer.callbacks = DictConfig(
50
+ {k: v for k, v in config.trainer.callbacks.items() if not k.startswith("ema")}
51
+ )
52
+ config.trainer.callbacks.update(EMA_CONST_CALLBACKS)
53
+
54
+ # Recommended setting for 2-step:
55
+ # config.model.sample_t_cfg.t_list = [0.999, 0.5, 0.0]
56
+ # config.model.student_sample_steps = 2
57
+
58
+ # dataloader
59
+ config.dataloader_train = ImageNet256_Loader_Config
60
+ config.dataloader_train.batch_size = 32
61
+
62
+ # trainer
63
+ config.trainer.batch_size_global = 1024
64
+ config.trainer.max_iter = 1200000
65
+ config.trainer.save_ckpt_iter = 50000
66
+ config.trainer.logging_iter = 10000
67
+
68
+ config.log_config.group = "imagenet256"
69
+
70
+ return config
FastGen/fastgen/configs/experiments/DiT/config_mf_xl.py ADDED
@@ -0,0 +1,26 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from fastgen.configs.net import DiT_IN256_XL_Config
5
+ import fastgen.configs.experiments.DiT.config_mf_b as config_mf_default
6
+
7
+
8
+ """ Configs for the MeanFlow model, on DiT-XL and ImageNet-256 dataset. """
9
+
10
+
11
def create_config():
    """Build the MeanFlow training config for DiT-XL on ImageNet-256."""
    config = config_mf_default.create_config()
    model = config.model

    model.guidance_t_end = 0.75
    model.guidance_scale = 0.2
    model.guidance_mixture_ratio = 0.92

    model.net = DiT_IN256_XL_Config
    # remove the additional 1000 factor in JVP
    model.net.scale_t = False
    model.net.r_timestep = True
    model.net.time_cond_type = "diff"

    config.dataloader_train.batch_size = 8

    return config
FastGen/fastgen/configs/experiments/DiT/config_sft_dit_xl.py ADDED
@@ -0,0 +1,68 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ import fastgen.configs.methods.config_sft as config_sft_default
7
+ from fastgen.configs.data import ImageNet256_Loader_Config
8
+ from fastgen.configs.net import DiT_IN256_XL_Config, CKPT_ROOT_DIR
9
+ from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
10
+
11
+ """Configs for SFT (Supervised Fine-Tuning) on DiT-XL and ImageNet-256 dataset."""
12
+
13
+
14
def create_config():
    """Build the SFT config for DiT-XL/2 on ImageNet-256."""
    config = config_sft_default.create_config()
    model = config.model

    # DiT latent shape [C, H, W] = [4, 32, 32] for 256x256 images (256/8 = 32).
    model.input_shape = [4, 32, 32]
    model.precision_amp = "bfloat16"
    model.cond_dropout_prob = 0.1  # 10% label dropout for CFG training

    # Timestep sampling.
    model.sample_t_cfg.time_dist_type = "logitnormal"
    model.sample_t_cfg.train_p_mean = -0.4
    model.sample_t_cfg.train_p_std = 1.0
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    # DiT-XL/2 ImageNet-256 checkpoint from https://github.com/facebookresearch/DiT
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-256/DiT-XL-2-256x256.pt"

    # The Facebook DiT checkpoint was trained with DDPM (epsilon prediction,
    # learned sigma) on the linear beta schedule, not with flow matching.
    model.net = DiT_IN256_XL_Config
    model.net.learn_sigma = True
    model.net.net_pred_type = "eps"
    model.net.schedule_type = "sd"  # same linear beta schedule as DiT (0.0001 to 0.02)

    # Fine-tuning optimizer (10x lower than the 1e-4 from-scratch LR).
    model.net_optimizer.optim_type = "adamw"
    model.net_optimizer.lr = 1e-5
    model.net_optimizer.betas = (0.9, 0.95)
    model.net_optimizer.weight_decay = 0.0

    # EMA: drop default ema callbacks and install the constant-rate ones.
    model.use_ema = ["ema_9999", "ema_99995"]
    kept = {name: cb for name, cb in config.trainer.callbacks.items() if not name.startswith("ema")}
    config.trainer.callbacks = DictConfig(kept)
    config.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    # Dataloader
    config.dataloader_train = ImageNet256_Loader_Config

    # Sampling steps for visualization (DiT typically uses 50-250).
    model.student_sample_steps = 50

    # Trainer
    config.trainer.batch_size_global = 256
    config.trainer.max_iter = 400000
    config.trainer.save_ckpt_iter = 10000
    config.trainer.logging_iter = 1000

    config.log_config.group = "dit_xl_imagenet256_sft"

    return config
FastGen/fastgen/configs/experiments/DiT/config_sft_sit_xl.py ADDED
@@ -0,0 +1,75 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ import fastgen.configs.methods.config_sft as config_sft_default
7
+ from fastgen.configs.data import ImageNet256_Loader_Config
8
+ from fastgen.configs.net import DiT_IN256_XL_Config, CKPT_ROOT_DIR
9
+ from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
10
+
11
+ """Configs for SFT (Supervised Fine-Tuning) on SiT-XL and ImageNet-256 dataset.
12
+
13
+ SiT (Scalable Interpolant Transformers) uses the same architecture as DiT but
14
+ is trained with flow matching (rectified flow) instead of DDPM.
15
+
16
+ Reference: https://github.com/willisma/SiT
17
+ """
18
+
19
+
20
def create_config():
    """Build the SFT config for SiT-XL/2 (flow matching) on ImageNet-256."""
    config = config_sft_default.create_config()
    model = config.model

    # SiT latent shape [C, H, W] = [4, 32, 32] for 256x256 images (256/8 = 32).
    model.input_shape = [4, 32, 32]
    model.precision_amp = "bfloat16"
    model.cond_dropout_prob = 0.1  # 10% label dropout for CFG training

    # SiT samples time uniformly during training.
    model.sample_t_cfg.time_dist_type = "uniform"
    model.sample_t_cfg.min_t = 0.001
    model.sample_t_cfg.max_t = 0.999

    # SiT-XL/2 ImageNet-256 checkpoint from https://github.com/willisma/SiT
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-256/SiT-XL-2-256x256.pt"

    # SiT reuses the DiT architecture but was trained with rectified flow.
    model.net = DiT_IN256_XL_Config
    model.net.learn_sigma = True  # checkpoint outputs 8 channels
    model.net.net_pred_type = "flow"  # flow/velocity prediction
    model.net.schedule_type = "rf"
    model.net.use_sit_convention = True  # SiT convention: t -> 1-t, v -> -v
    model.net.scale_t = False  # SiT uses continuous time t in [0, 1]

    # Fine-tuning optimizer (10x lower than the 1e-4 from-scratch LR).
    model.net_optimizer.optim_type = "adamw"
    model.net_optimizer.lr = 1e-5
    model.net_optimizer.betas = (0.9, 0.95)
    model.net_optimizer.weight_decay = 0.0

    # EMA: drop default ema callbacks and install the constant-rate ones.
    model.use_ema = ["ema_9999", "ema_99995"]
    kept = {name: cb for name, cb in config.trainer.callbacks.items() if not name.startswith("ema")}
    config.trainer.callbacks = DictConfig(kept)
    config.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    # Dataloader
    config.dataloader_train = ImageNet256_Loader_Config

    # Sampling config for visualization.
    model.student_sample_steps = 50
    model.guidance_scale = 4.0

    # Trainer
    config.trainer.batch_size_global = 256
    config.trainer.max_iter = 400000
    config.trainer.save_ckpt_iter = 10000

    config.log_config.group = "sit_xl_imagenet256_sft"

    return config
FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+ from fastgen.configs.methods.config_cm import create_config as cm_create_config
6
+ from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
7
+ from fastgen.configs.net import CKPT_ROOT_DIR
8
+
9
+
10
def create_config():
    """Build the consistency-model config for EDM on unconditional CIFAR-10."""
    config = cm_create_config()

    # Recommended setting for 2-step:
    # config.model.sample_t_cfg.t_list = [80.0, 0.821, 0.0]
    # config.model.student_sample_steps = 2

    config.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-uncond-vp.pth"

    # EMA: drop default ema callbacks and install the constant-rate ones.
    config.model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
    kept = {name: cb for name, cb in config.trainer.callbacks.items() if not name.startswith("ema")}
    config.trainer.callbacks = DictConfig(kept)
    config.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    config.trainer.max_iter = 350000
    config.trainer.batch_size_global = 512

    return config
FastGen/fastgen/configs/experiments/EDM/config_cm_cifar10_fast.py ADDED
@@ -0,0 +1,25 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from fastgen.configs.methods.config_cm import create_config as cm_create_config
5
+ from fastgen.configs.net import CKPT_ROOT_DIR
6
+
7
+
8
def create_config():
    """Build a fast-iteration CM config for EDM on unconditional CIFAR-10."""
    config = cm_create_config()
    trainer = config.trainer

    # Recommended setting for 2-step:
    # config.model.sample_t_cfg.t_list = [80.0, 0.821, 0.0]
    # config.model.student_sample_steps = 2

    config.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-uncond-vp.pth"

    trainer.callbacks.ct_schedule.kimg_per_stage = 512000
    config.dataloader_train.batch_size = 128

    # Short run with frequent checkpoints / logging.
    trainer.max_iter = 8000
    trainer.callbacks.ct_schedule.q = 256.0
    trainer.callbacks.ema.beta = 0.9993
    trainer.save_ckpt_iter = 500
    trainer.logging_iter = 100
    return config
FastGen/fastgen/configs/experiments/EDM/config_cm_in64.py ADDED
@@ -0,0 +1,65 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+
6
+ import fastgen.configs.methods.config_cm as config_cm_default
7
+ from fastgen.configs.data import ImageNet64_Loader_Config
8
+ from fastgen.configs.net import EDM_ImageNet64_Config, CKPT_ROOT_DIR
9
+ from fastgen.utils import LazyCall as L
10
+ from fastgen.utils.lr_scheduler import LambdaInverseSquareRootScheduler
11
+ from fastgen.configs.callbacks import EMA_POWER_CALLBACKS
12
+
13
+ """ Configs for the CM model, on EDM2 and ImageNet-64 dataset. """
14
+
15
+
16
def create_config():
    """Build the CM config for EDM on ImageNet-64."""
    config = config_cm_default.create_config()
    model, trainer = config.model, config.trainer

    # trainer / CT schedule
    # recommended setting for ImageNet-64 is max_iter * batch_size // (4 * 1000)
    trainer.callbacks.ct_schedule.kimg_per_stage = 3200
    trainer.callbacks.ct_schedule.q = 4
    trainer.callbacks.ct_schedule.ratio_limit = 0.9961

    # EMA: drop default ema callbacks and install the power-function ones.
    model.use_ema = ["ema_1", "ema_5", "ema_10"]
    kept = {name: cb for name, cb in trainer.callbacks.items() if not name.startswith("ema")}
    trainer.callbacks = DictConfig(kept)
    trainer.callbacks.update(EMA_POWER_CALLBACKS)

    # model: fp16 AMP with a dynamic grad scaler.
    model.precision_amp = "float16"
    model.grad_scaler_enabled = True
    model.grad_scaler_init_scale = 16
    model.grad_scaler_growth_interval = 20000
    model.pretrained_model_path = f"{CKPT_ROOT_DIR}/imagenet-64/edm-imagenet-64x64-cond-adm.pth"
    model.input_shape = [3, 64, 64]
    model.sample_t_cfg.train_p_mean = -0.8
    model.sample_t_cfg.train_p_std = 1.6
    model.loss_config.huber_const = 0.06
    model.loss_config.weighting_ct_loss = "c_out_sq"

    # Recommended setting for 2-step:
    # config.model.sample_t_cfg.t_list = [80.0, 1.526, 0.0]
    # config.model.student_sample_steps = 2

    model.net = EDM_ImageNet64_Config
    model.net.dropout = 0.2
    model.net_optimizer.optim_type = "adam"
    model.net_optimizer.lr = 1e-3
    model.net_optimizer.betas = (0.9, 0.99)
    model.net_optimizer.weight_decay = 0.0
    model.net_scheduler = L(LambdaInverseSquareRootScheduler)(
        warm_up_steps=0,
        decay_steps=2000,
    )
    # During inference, sigma_shift can improve 2-step results:
    # config.model.net.sigma_shift = 0.003

    # dataloader
    config.dataloader_train = ImageNet64_Loader_Config

    config.log_config.group = "edm_imagenet64_cm"
    return config
FastGen/fastgen/configs/experiments/EDM/config_dmd2_cifar10.py ADDED
@@ -0,0 +1,24 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2026 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+
4
+ from omegaconf import DictConfig
5
+ from fastgen.configs.methods.config_dmd2 import create_config as dmd2_create_config
6
+ from fastgen.configs.callbacks import EMA_CONST_CALLBACKS
7
+ from fastgen.configs.net import CKPT_ROOT_DIR
8
+
9
+
10
def create_config():
    """Build the DMD2 config for EDM on class-conditional CIFAR-10."""
    config = dmd2_create_config()

    config.model.pretrained_model_path = f"{CKPT_ROOT_DIR}/cifar10/edm-cifar10-32x32-cond-vp.pth"

    # EMA: drop default ema callbacks and install the constant-rate ones.
    config.model.use_ema = ["ema_9999", "ema_99995", "ema_9996"]
    kept = {name: cb for name, cb in config.trainer.callbacks.items() if not name.startswith("ema")}
    config.trainer.callbacks = DictConfig(kept)
    config.trainer.callbacks.update(EMA_CONST_CALLBACKS)

    config.trainer.max_iter = 100000
    config.trainer.batch_size_global = 2048

    return config