Muhammad Taqi Raza committed
Commit af758d1 · 1 Parent(s): e3a5735

adding lyra files

This view is limited to 50 files because it contains too many changes. See raw diff.
Files changed (50)
  1. .gitattributes +1 -0
  2. ATTRIBUTIONS.md +0 -0
  3. CONTRIBUTING.md +51 -0
  4. INSTALL.md +90 -0
  5. LICENSE +201 -0
  6. README copy.md +163 -0
  7. configs/accelerate/accelerate_config.yaml +17 -0
  8. configs/accelerate/accelerate_config_single.yaml +17 -0
  9. configs/demo/lyra_dynamic.yaml +30 -0
  10. configs/demo/lyra_static.yaml +21 -0
  11. configs/inference/3dgs_res_176_320_views_17.yaml +4 -0
  12. configs/inference/3dgs_res_176_320_views_49.yaml +4 -0
  13. configs/inference/3dgs_res_352_640_views_49.yaml +4 -0
  14. configs/inference/3dgs_res_704_1280_views_121.yaml +4 -0
  15. configs/inference/3dgs_res_704_1280_views_121_multi_6.yaml +5 -0
  16. configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml +19 -0
  17. configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml +19 -0
  18. configs/inference/3dgs_res_704_1280_views_121_multi_6_prune.yaml +5 -0
  19. configs/inference/3dgs_res_704_1280_views_49.yaml +4 -0
  20. configs/inference/default.yaml +46 -0
  21. configs/training/3dgs_res_176_320_views_17.yaml +9 -0
  22. configs/training/3dgs_res_176_320_views_49.yaml +9 -0
  23. configs/training/3dgs_res_352_640_views_49.yaml +10 -0
  24. configs/training/3dgs_res_704_1280_views_121.yaml +10 -0
  25. configs/training/3dgs_res_704_1280_views_121_multi_6.yaml +12 -0
  26. configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml +16 -0
  27. configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml +19 -0
  28. configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml +15 -0
  29. configs/training/3dgs_res_704_1280_views_49.yaml +10 -0
  30. configs/training/default.yaml +160 -0
  31. cosmos_predict1/__init__.py +14 -0
  32. cosmos_predict1/autoregressive/__init__.py +14 -0
  33. cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py +352 -0
  34. cosmos_predict1/autoregressive/configs/__init__.py +14 -0
  35. cosmos_predict1/autoregressive/configs/base/__init__.py +14 -0
  36. cosmos_predict1/autoregressive/configs/base/callbacks.py +33 -0
  37. cosmos_predict1/autoregressive/configs/base/dataloader.py +72 -0
  38. cosmos_predict1/autoregressive/configs/base/dataset.py +39 -0
  39. cosmos_predict1/autoregressive/configs/base/model.py +318 -0
  40. cosmos_predict1/autoregressive/configs/base/model_config.py +718 -0
  41. cosmos_predict1/autoregressive/configs/base/model_parallel.py +33 -0
  42. cosmos_predict1/autoregressive/configs/base/optim.py +86 -0
  43. cosmos_predict1/autoregressive/configs/base/tokenizer.py +139 -0
  44. cosmos_predict1/autoregressive/configs/config.py +111 -0
  45. cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py +0 -0
  46. cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py +163 -0
  47. cosmos_predict1/autoregressive/configs/inference/inference_config.py +102 -0
  48. cosmos_predict1/autoregressive/configs/registry.py +89 -0
  49. cosmos_predict1/autoregressive/datasets/dataset_utils.py +173 -0
  50. cosmos_predict1/autoregressive/datasets/video_dataset.py +190 -0
.gitattributes CHANGED
@@ -33,3 +33,4 @@ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
  *.zip filter=lfs diff=lfs merge=lfs -text
  *.zst filter=lfs diff=lfs merge=lfs -text
  *tfevents* filter=lfs diff=lfs merge=lfs -text
+ *.mp4 filter=lfs diff=lfs merge=lfs -text
ATTRIBUTIONS.md ADDED
The diff for this file is too large to render. See raw diff
 
CONTRIBUTING.md ADDED
@@ -0,0 +1,51 @@
+ # How to Contribute
+
+ We'd love to receive your patches and contributions. Please keep your PRs in draft until you would like us to review them.
+
+ ## Code Reviews
+
+ All submissions, including submissions by project members, require review. We use GitHub pull requests for this purpose. Consult
+ [GitHub Help](https://help.github.com/articles/about-pull-requests/) for more information on using pull requests.
+
+ ## Signing Your Work
+
+ * We require that all contributors "sign off" on their commits. This certifies that the contribution is your original work, or that you have the right to submit it under the same or a compatible license.
+
+ * Any contribution that contains commits that are not signed off will not be accepted.
+
+ * To sign off on a commit, simply use the `--signoff` (or `-s`) option when committing your changes (if you forgot, see the example after the DCO text below for adding the sign-off retroactively):
+ ```bash
+ $ git commit -s -m "Add cool feature."
+ ```
+ This will append the following to your commit message:
+ ```
+ Signed-off-by: Your Name <your@email.com>
+ ```
+
+ * Full text of the DCO:
+
+ ```
+ Developer Certificate of Origin
+ Version 1.1
+
+ Copyright (C) 2004, 2006 The Linux Foundation and its contributors.
+ 1 Letterman Drive
+ Suite D4700
+ San Francisco, CA, 94129
+
+ Everyone is permitted to copy and distribute verbatim copies of this license document, but changing it is not allowed.
+ ```
+
+ ```
+ Developer's Certificate of Origin 1.1
+
+ By making a contribution to this project, I certify that:
+
+ (a) The contribution was created in whole or in part by me and I have the right to submit it under the open source license indicated in the file; or
+
+ (b) The contribution is based upon previous work that, to the best of my knowledge, is covered under an appropriate open source license and I have the right under that license to submit that work with modifications, whether created in whole or in part by me, under the same open source license (unless I am permitted to submit under a different license), as indicated in the file; or
+
+ (c) The contribution was provided directly to me by some other person who certified (a), (b) or (c) and I have not modified it.
+
+ (d) I understand and agree that this project and the contribution are public and that a record of the contribution (including all personal information I submit with it, including my sign-off) is maintained indefinitely and may be redistributed consistent with this project or the open source license(s) involved.
+ ```
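+
+ If a commit is missing the sign-off, you can add it retroactively. These are standard `git` options rather than project-specific tooling:
+ ```bash
+ # Amend the most recent commit in place, adding a Signed-off-by trailer.
+ git commit --amend --signoff --no-edit
+ # Or add sign-offs to the last N commits (here, the last 3) while rebasing.
+ git rebase --signoff HEAD~3
+ ```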
INSTALL.md ADDED
@@ -0,0 +1,90 @@
+ ### Environment setup
+
+ Cosmos runs only on Linux systems. We have tested the installation with Ubuntu 24.04, 22.04, and 20.04.
+ Cosmos requires the Python version to be `3.10.x`. Please also make sure you have `conda` installed ([instructions](https://docs.conda.io/projects/conda/en/latest/user-guide/install/index.html)).
+
+ The commands below create the `lyra` conda environment and install the dependencies for inference:
+ ```bash
+ # Create the lyra conda environment.
+ conda env create --file lyra.yaml
+ # Activate the lyra conda environment.
+ conda activate lyra
+ # Install the dependencies.
+ pip install -r requirements_gen3c.txt
+ pip install -r requirements_lyra.txt
+ # Patch Transformer Engine linking issues in conda environments.
+ ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/
+ ln -sf $CONDA_PREFIX/lib/python3.10/site-packages/nvidia/*/include/* $CONDA_PREFIX/include/python3.10
+ # Install Transformer Engine.
+ pip install transformer-engine[pytorch]==1.12.0
+ # Install Apex for inference.
+ git clone https://github.com/NVIDIA/apex
+ CUDA_HOME=$CONDA_PREFIX pip install -v --disable-pip-version-check --no-cache-dir --no-build-isolation --config-settings "--build-option=--cpp_ext" --config-settings "--build-option=--cuda_ext" ./apex
+ # Install MoGe for inference.
+ pip install git+https://github.com/microsoft/MoGe.git
+ # Install Mamba for the reconstruction model.
+ pip install --no-build-isolation "git+https://github.com/state-spaces/mamba@v2.2.4"
+ ```
+
+ You can test the environment setup for inference with
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/test_environment.py
+ ```
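+
+ If the test script reports problems, a quick sanity check of the CUDA-enabled PyTorch install (a minimal sketch; it assumes only the packages installed above) is:
+ ```bash
+ python -c "import torch; print(torch.__version__, torch.cuda.is_available())"
+ ```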
+
+ ### Download Cosmos-Predict1 tokenizer
+
+ 1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token (if you haven't done so already). Set the access token to `Read` permission (default is `Fine-grained`).
+
+ 2. Log in to Hugging Face with the access token:
+ ```bash
+ huggingface-cli login
+ ```
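+
+ Alternatively, for a non-interactive login (for example, on a remote machine), the Hugging Face CLI can take the token directly; this assumes your token is exported as `HF_TOKEN`:
+ ```bash
+ huggingface-cli login --token "$HF_TOKEN"
+ ```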
+
+ 3. Download the Cosmos Tokenize model weights from [Hugging Face](https://huggingface.co/collections/nvidia/cosmos-predict1-67c9d1b97678dbf7669c89a7):
+ ```bash
+ python3 -m scripts.download_tokenizer_checkpoints --checkpoint_dir checkpoints/cosmos_predict1 --tokenizer_types CV8x8x8-720p
+ ```
+
+ The downloaded files should be in the following structure:
+ ```
+ checkpoints/
+ ├── Cosmos-Tokenize1-CV8x8x8-720p
+ ├── Cosmos-Tokenize1-DV8x16x16-720p
+ ├── Cosmos-Tokenize1-CI8x8-360p
+ ├── Cosmos-Tokenize1-CI16x16-360p
+ ├── Cosmos-Tokenize1-CV4x8x8-360p
+ ├── Cosmos-Tokenize1-DI8x8-360p
+ ├── Cosmos-Tokenize1-DI16x16-360p
+ └── Cosmos-Tokenize1-DV4x8x8-360p
+ ```
+
+ Under the checkpoint repository `checkpoints/<model-name>`, we provide the encoder, the decoder, and the full autoencoder in TorchScript (PyTorch JIT mode), as well as the native PyTorch checkpoint. For instance, for the `Cosmos-Tokenize1-CV8x8x8-720p` model:
+ ```bash
+ ├── checkpoints/
+ │ ├── Cosmos-Tokenize1-CV8x8x8-720p/
+ │ │ ├── encoder.jit
+ │ │ ├── decoder.jit
+ │ │ ├── autoencoder.jit
+ │ │ ├── model.pt
+ ```
+
+ ### Download GEN3C checkpoints
+
+ 1. Generate a [Hugging Face](https://huggingface.co/settings/tokens) access token (if you haven't done so already). Set the access token to `Read` permission (default is `Fine-grained`).
+
+ 2. Log in to Hugging Face with the access token:
+ ```bash
+ huggingface-cli login
+ ```
+
+ 3. Download the GEN3C model weights from [Hugging Face](https://huggingface.co/nvidia/GEN3C-Cosmos-7B):
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_gen3c_checkpoints.py --checkpoint_dir checkpoints
+ ```
+
+ ### Download Lyra checkpoints
+
+ 1. Download the Lyra model weights from [Hugging Face](https://huggingface.co/nvidia/Lyra):
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) python scripts/download_lyra_checkpoints.py --checkpoint_dir checkpoints
+ ```
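+
+ As a final check, you can verify that the checkpoint paths referenced by the download commands above and by the demo configs exist (a sketch; adjust if you used a different `--checkpoint_dir`):
+ ```bash
+ for p in checkpoints/cosmos_predict1/Cosmos-Tokenize1-CV8x8x8-720p checkpoints/Lyra/lyra_static.pt checkpoints/Lyra/lyra_dynamic.pt; do
+   [ -e "$p" ] && echo "OK       $p" || echo "MISSING  $p"
+ done
+ ```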
LICENSE ADDED
@@ -0,0 +1,201 @@
+ Apache License
+ Version 2.0, January 2004
+ http://www.apache.org/licenses/
+
+ TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
+
+ 1. Definitions.
+
+ "License" shall mean the terms and conditions for use, reproduction,
+ and distribution as defined by Sections 1 through 9 of this document.
+
+ "Licensor" shall mean the copyright owner or entity authorized by
+ the copyright owner that is granting the License.
+
+ "Legal Entity" shall mean the union of the acting entity and all
+ other entities that control, are controlled by, or are under common
+ control with that entity. For the purposes of this definition,
+ "control" means (i) the power, direct or indirect, to cause the
+ direction or management of such entity, whether by contract or
+ otherwise, or (ii) ownership of fifty percent (50%) or more of the
+ outstanding shares, or (iii) beneficial ownership of such entity.
+
+ "You" (or "Your") shall mean an individual or Legal Entity
+ exercising permissions granted by this License.
+
+ "Source" form shall mean the preferred form for making modifications,
+ including but not limited to software source code, documentation
+ source, and configuration files.
+
+ "Object" form shall mean any form resulting from mechanical
+ transformation or translation of a Source form, including but
+ not limited to compiled object code, generated documentation,
+ and conversions to other media types.
+
+ "Work" shall mean the work of authorship, whether in Source or
+ Object form, made available under the License, as indicated by a
+ copyright notice that is included in or attached to the work
+ (an example is provided in the Appendix below).
+
+ "Derivative Works" shall mean any work, whether in Source or Object
+ form, that is based on (or derived from) the Work and for which the
+ editorial revisions, annotations, elaborations, or other modifications
+ represent, as a whole, an original work of authorship. For the purposes
+ of this License, Derivative Works shall not include works that remain
+ separable from, or merely link (or bind by name) to the interfaces of,
+ the Work and Derivative Works thereof.
+
+ "Contribution" shall mean any work of authorship, including
+ the original version of the Work and any modifications or additions
+ to that Work or Derivative Works thereof, that is intentionally
+ submitted to Licensor for inclusion in the Work by the copyright owner
+ or by an individual or Legal Entity authorized to submit on behalf of
+ the copyright owner. For the purposes of this definition, "submitted"
+ means any form of electronic, verbal, or written communication sent
+ to the Licensor or its representatives, including but not limited to
+ communication on electronic mailing lists, source code control systems,
+ and issue tracking systems that are managed by, or on behalf of, the
+ Licensor for the purpose of discussing and improving the Work, but
+ excluding communication that is conspicuously marked or otherwise
+ designated in writing by the copyright owner as "Not a Contribution."
+
+ "Contributor" shall mean Licensor and any individual or Legal Entity
+ on behalf of whom a Contribution has been received by Licensor and
+ subsequently incorporated within the Work.
+
+ 2. Grant of Copyright License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ copyright license to reproduce, prepare Derivative Works of,
+ publicly display, publicly perform, sublicense, and distribute the
+ Work and such Derivative Works in Source or Object form.
+
+ 3. Grant of Patent License. Subject to the terms and conditions of
+ this License, each Contributor hereby grants to You a perpetual,
+ worldwide, non-exclusive, no-charge, royalty-free, irrevocable
+ (except as stated in this section) patent license to make, have made,
+ use, offer to sell, sell, import, and otherwise transfer the Work,
+ where such license applies only to those patent claims licensable
+ by such Contributor that are necessarily infringed by their
+ Contribution(s) alone or by combination of their Contribution(s)
+ with the Work to which such Contribution(s) was submitted. If You
+ institute patent litigation against any entity (including a
+ cross-claim or counterclaim in a lawsuit) alleging that the Work
+ or a Contribution incorporated within the Work constitutes direct
+ or contributory patent infringement, then any patent licenses
+ granted to You under this License for that Work shall terminate
+ as of the date such litigation is filed.
+
+ 4. Redistribution. You may reproduce and distribute copies of the
+ Work or Derivative Works thereof in any medium, with or without
+ modifications, and in Source or Object form, provided that You
+ meet the following conditions:
+
+ (a) You must give any other recipients of the Work or
+ Derivative Works a copy of this License; and
+
+ (b) You must cause any modified files to carry prominent notices
+ stating that You changed the files; and
+
+ (c) You must retain, in the Source form of any Derivative Works
+ that You distribute, all copyright, patent, trademark, and
+ attribution notices from the Source form of the Work,
+ excluding those notices that do not pertain to any part of
+ the Derivative Works; and
+
+ (d) If the Work includes a "NOTICE" text file as part of its
+ distribution, then any Derivative Works that You distribute must
+ include a readable copy of the attribution notices contained
+ within such NOTICE file, excluding those notices that do not
+ pertain to any part of the Derivative Works, in at least one
+ of the following places: within a NOTICE text file distributed
+ as part of the Derivative Works; within the Source form or
+ documentation, if provided along with the Derivative Works; or,
+ within a display generated by the Derivative Works, if and
+ wherever such third-party notices normally appear. The contents
+ of the NOTICE file are for informational purposes only and
+ do not modify the License. You may add Your own attribution
+ notices within Derivative Works that You distribute, alongside
+ or as an addendum to the NOTICE text from the Work, provided
+ that such additional attribution notices cannot be construed
+ as modifying the License.
+
+ You may add Your own copyright statement to Your modifications and
+ may provide additional or different license terms and conditions
+ for use, reproduction, or distribution of Your modifications, or
+ for any such Derivative Works as a whole, provided Your use,
+ reproduction, and distribution of the Work otherwise complies with
+ the conditions stated in this License.
+
+ 5. Submission of Contributions. Unless You explicitly state otherwise,
+ any Contribution intentionally submitted for inclusion in the Work
+ by You to the Licensor shall be under the terms and conditions of
+ this License, without any additional terms or conditions.
+ Notwithstanding the above, nothing herein shall supersede or modify
+ the terms of any separate license agreement you may have executed
+ with Licensor regarding such Contributions.
+
+ 6. Trademarks. This License does not grant permission to use the trade
+ names, trademarks, service marks, or product names of the Licensor,
+ except as required for reasonable and customary use in describing the
+ origin of the Work and reproducing the content of the NOTICE file.
+
+ 7. Disclaimer of Warranty. Unless required by applicable law or
+ agreed to in writing, Licensor provides the Work (and each
+ Contributor provides its Contributions) on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
+ implied, including, without limitation, any warranties or conditions
+ of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
+ PARTICULAR PURPOSE. You are solely responsible for determining the
+ appropriateness of using or redistributing the Work and assume any
+ risks associated with Your exercise of permissions under this License.
+
+ 8. Limitation of Liability. In no event and under no legal theory,
+ whether in tort (including negligence), contract, or otherwise,
+ unless required by applicable law (such as deliberate and grossly
+ negligent acts) or agreed to in writing, shall any Contributor be
+ liable to You for damages, including any direct, indirect, special,
+ incidental, or consequential damages of any character arising as a
+ result of this License or out of the use or inability to use the
+ Work (including but not limited to damages for loss of goodwill,
+ work stoppage, computer failure or malfunction, or any and all
+ other commercial damages or losses), even if such Contributor
+ has been advised of the possibility of such damages.
+
+ 9. Accepting Warranty or Additional Liability. While redistributing
+ the Work or Derivative Works thereof, You may choose to offer,
+ and charge a fee for, acceptance of support, warranty, indemnity,
+ or other liability obligations and/or rights consistent with this
+ License. However, in accepting such obligations, You may act only
+ on Your own behalf and on Your sole responsibility, not on behalf
+ of any other Contributor, and only if You agree to indemnify,
+ defend, and hold each Contributor harmless for any liability
+ incurred by, or claims asserted against, such Contributor by reason
+ of your accepting any such warranty or additional liability.
+
+ END OF TERMS AND CONDITIONS
+
+ APPENDIX: How to apply the Apache License to your work.
+
+ To apply the Apache License to your work, attach the following
+ boilerplate notice, with the fields enclosed by brackets "[]"
+ replaced with your own identifying information. (Don't include
+ the brackets!) The text should be enclosed in the appropriate
+ comment syntax for the file format. We also recommend that a
+ file or class name and description of purpose be included on the
+ same "printed page" as the copyright notice for easier
+ identification within third-party archives.
+
+ Copyright [yyyy] [name of copyright owner]
+
+ Licensed under the Apache License, Version 2.0 (the "License");
+ you may not use this file except in compliance with the License.
+ You may obtain a copy of the License at
+
+ http://www.apache.org/licenses/LICENSE-2.0
+
+ Unless required by applicable law or agreed to in writing, software
+ distributed under the License is distributed on an "AS IS" BASIS,
+ WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ See the License for the specific language governing permissions and
+ limitations under the License.
README copy.md ADDED
@@ -0,0 +1,163 @@
+ # Lyra: Generative 3D Scene Reconstruction via Video Diffusion Model Self-Distillation
+
+ <p align="center">
+ <img src="https://github.com/user-attachments/assets/12d44362-8b7f-4952-9488-0e45cf759b57" alt="teaser"/>
+ </p>
+
+ **TL;DR: Feed-forward 3D and 4D scene generation from a single image/video, trained with synthetic data generated by a camera-controlled video diffusion model.**
+
+ **Full Abstract**:
+ The ability to generate virtual environments is crucial for applications ranging from gaming to physical AI domains such as robotics, autonomous driving, and industrial AI. Current learning-based 3D reconstruction methods rely on the availability of captured real-world multi-view data, which is not always readily available. Recent advancements in video diffusion models have shown remarkable imagination capabilities, yet their 2D nature limits their applications in simulation, where a robot needs to navigate and interact with the environment. In this paper, we propose a self-distillation framework that aims to distill the implicit 3D knowledge in video diffusion models into an explicit 3D Gaussian Splatting (3DGS) representation, eliminating the need for multi-view training data. Specifically, we augment the typical RGB decoder with a 3DGS decoder, which is supervised by the output of the RGB decoder. In this approach, the 3DGS decoder can be trained purely with synthetic data generated by video diffusion models. At inference time, our model can synthesize 3D scenes from either a text prompt or a single image for real-time rendering. Our framework further extends to dynamic 3D scene generation from a monocular input video. Experimental results show that our framework achieves state-of-the-art performance in static and dynamic 3D scene generation.
+
+ **[Paper](https://arxiv.org/abs/2509.19296), [Project Page](https://research.nvidia.com/labs/toronto-ai/lyra/), [Dataset](https://huggingface.co/datasets/nvidia/PhysicalAI-SpatialIntelligence-Lyra-SDG)**
+
+ [Sherwin Bahmani](https://sherwinbahmani.github.io/),
+ [Tianchang Shen](https://www.cs.toronto.edu/~shenti11/),
+ [Jiawei Ren](https://jiawei-ren.github.io/),
+ [Jiahui Huang](https://huangjh-pub.github.io/),
+ [Yifeng Jiang](https://cs.stanford.edu/~yifengj/),
+ [Haithem Turki](https://haithemturki.com/),
+ [Andrea Tagliasacchi](https://theialab.ca/),
+ [David B. Lindell](https://davidlindell.com/),
+ [Zan Gojcic](https://zgojcic.github.io/),
+ [Sanja Fidler](https://www.cs.utoronto.ca/~fidler/),
+ [Huan Ling](https://www.cs.toronto.edu/~linghuan/),
+ [Jun Gao](https://www.cs.toronto.edu/~jungao/),
+ [Xuanchi Ren](https://xuanchiren.com/) <br>
+
+ ## Installation
+
+ Please follow [INSTALL.md](INSTALL.md) to set up your conda environment and download the pre-trained weights.
+
+ ## Demo
+ Lyra supports both images and videos as input. Below are examples of running Lyra on single images and videos.
+
+ First, download the demo samples:
+
+ ```bash
+ # Download test samples from Hugging Face
+ huggingface-cli download nvidia/Lyra-Testing-Example --repo-type dataset --local-dir assets/demo
+ ```
+
+ ### Example 1: Single Image to 3D Gaussians Generation
+
+ 1) Generate multi-view video latents from the input image using scripts/bash/static_sdg.sh.
+
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=1 cosmos_predict1/diffusion/inference/gen3c_single_image_sdg.py \
+ --checkpoint_dir checkpoints \
+ --num_gpus 1 \
+ --input_image_path assets/demo/static/diffusion_input/images/00172.png \
+ --video_save_folder assets/demo/static/diffusion_output_generated \
+ --foreground_masking \
+ --multi_trajectory
+ ```
+
+ If you want to skip the diffusion part, we provide pre-generated latents in assets/demo/static/diffusion_output, and these are used by default. To use your own generated latents instead, change dataset_name in configs/demo/lyra_static.yaml from lyra_static_demo to lyra_static_demo_generated.
+
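+ For reference, the relevant lines of `configs/demo/lyra_static.yaml` (included in this commit) look like this; switch the active `dataset_name` to pick the source of the latents:
+ ```yaml
+ # Define dataset name defined in src/models/data/registry.py
+ dataset_name: lyra_static_demo # Use pre-generated latents
+ # dataset_name: lyra_static_demo_generated # Generate own latents
+ ```
+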
+ 2) Reconstruct the multi-view video latents with the 3DGS decoder (if you ran step 1, change dataset_name in the .yaml to the generated variant):
+
+ ```bash
+ accelerate launch sample.py --config configs/demo/lyra_static.yaml
+ ```
+
+ ### Example 2: Single Video to Dynamic 3D Gaussians Generation
+
+ 1) Generate multi-view video latents from the input video and ViPE-estimated depth using scripts/bash/dynamic_sdg.sh.
+
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=1 cosmos_predict1/diffusion/inference/gen3c_dynamic_sdg.py \
+ --checkpoint_dir checkpoints \
+ --vipe_path assets/demo/dynamic/diffusion_input/rgb/6a71ee0422ff4222884f1b2a3cba6820.mp4 \
+ --video_save_folder assets/demo/dynamic/diffusion_output \
+ --disable_prompt_upsampler \
+ --num_gpus 1 \
+ --foreground_masking \
+ --multi_trajectory
+ ```
+
+ If you want to skip the diffusion part, we provide pre-generated latents in assets/demo/dynamic/diffusion_output, and these are used by default. To use your own generated latents instead, change dataset_name in configs/demo/lyra_dynamic.yaml from lyra_dynamic_demo to lyra_dynamic_demo_generated.
+ Add --flip_supervision if you also want to generate the motion-reversed training data (not needed for inference).
+
+ 2) Reconstruct the multi-view video latents with the 3DGS decoder (if you ran step 1, change dataset_name in the .yaml to the generated variant):
+
+ ```bash
+ accelerate launch sample.py --config configs/demo/lyra_dynamic.yaml
+ ```
+
+ #### Testing on your own videos using ViPE
+ Follow the installation instructions for [ViPE](https://github.com/nv-tlabs/vipe). Note: ViPE's environment is not compatible with Lyra's, so we recommend installing ViPE in a separate conda environment. The ViPE results are required for dynamic scene generation. Moreover, we use the depth from ViPE for depth supervision during 3DGS decoder training.
+
+ 1) Run ViPE to extract depth, intrinsics, and camera poses (make sure to use the `-p lyra` option so that ViPE uses the same depth estimator as us):
+ ```bash
+ vipe infer YOUR_VIDEO.mp4 -p lyra --output <vipe_results_dir>
+ ```
+
+ 2) Register the new data path as a dataset in src/models/data/registry.py, following the structure of our provided datasets.
+
+ ### GPU Memory Requirements
+
+ We have tested Lyra only on H100 and A100 GPUs. For GPUs with limited memory, you can fully offload all models by appending the following flags to your SDG command:
+
+ ```bash
+ --offload_diffusion_transformer \
+ --offload_tokenizer \
+ --offload_text_encoder_model \
+ --offload_prompt_upsampler \
+ --offload_guardrail_models \
+ --disable_guardrail \
+ --disable_prompt_encoder
+ ```
+ Maximum observed memory during inference with full offloading: ~43 GB. Note: memory usage may vary depending on system specifications and is provided for reference only.
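+
+ For example, the single-image command from Example 1 with full offloading enabled (the flags above are simply appended; adjust the paths to your own run) would look like:
+ ```bash
+ CUDA_HOME=$CONDA_PREFIX PYTHONPATH=$(pwd) torchrun --nproc_per_node=1 cosmos_predict1/diffusion/inference/gen3c_single_image_sdg.py \
+ --checkpoint_dir checkpoints \
+ --num_gpus 1 \
+ --input_image_path assets/demo/static/diffusion_input/images/00172.png \
+ --video_save_folder assets/demo/static/diffusion_output_generated \
+ --foreground_masking \
+ --multi_trajectory \
+ --offload_diffusion_transformer \
+ --offload_tokenizer \
+ --offload_text_encoder_model \
+ --offload_prompt_upsampler \
+ --offload_guardrail_models \
+ --disable_guardrail \
+ --disable_prompt_encoder
+ ```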
+
+ ## Training
+
+ We provide training scripts to train from scratch or fine-tune our models. First, you need to download our [training data](https://huggingface.co/datasets/nvidia/PhysicalAI-SpatialIntelligence-Lyra-SDG):
+
+ ```bash
+ # Download our training datasets from Hugging Face and untar them into a static/dynamic folder
+ huggingface-cli download nvidia/PhysicalAI-SpatialIntelligence-Lyra-SDG --repo-type dataset --local-dir lyra_dataset/tar
+ ```
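+
+ The exact archive layout may differ; as a rough sketch, assuming the download places per-split tar files under lyra_dataset/tar/static and lyra_dataset/tar/dynamic, extraction could look like:
+ ```bash
+ mkdir -p lyra_dataset/static lyra_dataset/dynamic
+ for split in static dynamic; do
+   for f in lyra_dataset/tar/"$split"/*.tar; do
+     tar -xf "$f" -C lyra_dataset/"$split"
+   done
+ done
+ ```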
+
+ Alternatively, use the demo script to generate training data. For this, running only the diffusion part is sufficient; the 3DGS decoder does not need to be run, since it is the component we are training. Make sure to update the paths in src/models/data/registry.py for lyra_static / lyra_dynamic to wherever your data is stored. We provide our progressive training script:
+
+ ```bash
+ bash train.sh
+ ```
+
+ We provide visualization scripts during training to export renderings and 3D Gaussians for each stage:
+
+ ```bash
+ bash inference.sh
+ ```
+
+ ## Acknowledgement
+ Our model is based on [NVIDIA Cosmos](https://github.com/NVIDIA/Cosmos) and [GEN3C](https://github.com/nv-tlabs/GEN3C). We use input images generated by [Flux](https://github.com/black-forest-labs/flux).
+
+ We are also grateful to several other open-source repositories that we drew inspiration from or built upon during the development of our pipeline:
+ - [MoGe](https://github.com/microsoft/MoGe)
+ - [TrajectoryCrafter](https://github.com/TrajectoryCrafter/TrajectoryCrafter)
+ - [DimensionX](https://github.com/wenqsun/DimensionX)
+ - [Depth Anything V2](https://github.com/DepthAnything/Depth-Anything-V2)
+ - [Video Depth Anything](https://github.com/DepthAnything/Video-Depth-Anything)
+ - [Long-LRM](https://github.com/arthurhero/Long-LRM)
+
+ ## Citation
+ ```
+ @inproceedings{bahmani2025lyra,
+ title={Lyra: Generative 3D Scene Reconstruction via Video Diffusion Model Self-Distillation},
+ author={Bahmani, Sherwin and Shen, Tianchang and Ren, Jiawei and Huang, Jiahui and Jiang, Yifeng and
+ Turki, Haithem and Tagliasacchi, Andrea and Lindell, David B. and Gojcic, Zan and Fidler, Sanja and
+ Ling, Huan and Gao, Jun and Ren, Xuanchi},
+ booktitle={arXiv preprint arXiv:2509.19296},
+ year={2025}
+ }
+ ```
+
+ ## License and Contact
+
+ This project will download and install additional third-party open source software projects. Review the license terms of these open source projects before use.
+
+ Lyra source code is released under the [Apache 2 License](https://www.apache.org/licenses/LICENSE-2.0).
+
+ Lyra models are released under the [NVIDIA Open Model License](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). For a custom license, please visit our website and submit the form: [NVIDIA Research Licensing](https://www.nvidia.com/en-us/research/inquiries/).
configs/accelerate/accelerate_config.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: MULTI_GPU
+ enable_cpu_affinity: false
+ downcast_bf16: 'no'
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'no'
+ num_machines: 1
+ num_processes: 8
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
configs/accelerate/accelerate_config_single.yaml ADDED
@@ -0,0 +1,17 @@
+ compute_environment: LOCAL_MACHINE
+ debug: false
+ distributed_type: NO
+ enable_cpu_affinity: false
+ downcast_bf16: 'no'
+ gpu_ids: all
+ machine_rank: 0
+ main_training_function: main
+ mixed_precision: 'no'
+ num_machines: 1
+ num_processes: 1
+ rdzv_backend: static
+ same_network: true
+ tpu_env: []
+ tpu_use_cluster: false
+ tpu_use_sudo: false
+ use_cpu: false
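+
+ # Usage sketch (an assumption about how these configs are consumed, based on the
+ # `accelerate launch` commands in the README; adjust the script and config to your run):
+ #   accelerate launch --config_file configs/accelerate/accelerate_config_single.yaml \
+ #     sample.py --config configs/demo/lyra_static.yaml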
configs/demo/lyra_dynamic.yaml ADDED
@@ -0,0 +1,30 @@
+ # Save all renderings etc. in this folder
+ out_dir_inference: outputs/demo/lyra_dynamic
+
+ # Dataset name defined in src/models/data/registry.py
+ dataset_name: lyra_dynamic_demo # Use pre-generated latents
+ # dataset_name: lyra_dynamic_demo_generated # Generate own latents
+
+ # Order of camera trajectory indices
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+
+ # Only render every 4th frame
+ target_index_subsample: 4
+
+ # Inherit the model settings etc. from these configs
+ config_path: [configs/training/default.yaml, configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml]
+
+ # For dynamic scenes, set target time
+ set_manual_time_idx: true
+
+ # Only create outputs for specified target times (between 0=min and 120=max)
+ target_index_manual: [0, 60, 120]
+
+ # Alternative: Loop over start and number of time indices with given stride
+ # target_index_manual: null
+ # target_index_manual_stride: 1
+ # target_index_manual_start_idx: 0
+ # target_index_manual_num_idx: 121
+
+ # Update path to where the dynamic Lyra checkpoint is downloaded
+ ckpt_path: checkpoints/Lyra/lyra_dynamic.pt
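+
+ # Example (sketch): to render all 121 time indices instead of the three listed above,
+ # enable the stride-based alternative and disable the explicit list, i.e.:
+ #   target_index_manual: null
+ #   target_index_manual_stride: 1
+ #   target_index_manual_start_idx: 0
+ #   target_index_manual_num_idx: 121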
configs/demo/lyra_static.yaml ADDED
@@ -0,0 +1,21 @@
+ # Save all renderings etc. in this folder
+ out_dir_inference: outputs/demo/lyra_static
+
+ # Dataset name defined in src/models/data/registry.py
+ dataset_name: lyra_static_demo # Use pre-generated latents
+ # dataset_name: lyra_static_demo_generated # Generate own latents
+
+ # Order of camera trajectory indices
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+
+ # Only render every 4th frame
+ target_index_subsample: 4
+
+ # For static scenes, do not set target time
+ set_manual_time_idx: true
+
+ # Inherit the model settings etc. from these configs
+ config_path: [configs/training/default.yaml, configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml]
+
+ # Update path to where the static Lyra checkpoint is downloaded
+ ckpt_path: checkpoints/Lyra/lyra_static.pt
configs/inference/3dgs_res_176_320_views_17.yaml ADDED
@@ -0,0 +1,4 @@
+ config_path: outputs/training/3dgs_res_176_320_views_17/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_176_320_views_17
+
+ static_view_indices_fixed: ['0']
configs/inference/3dgs_res_176_320_views_49.yaml ADDED
@@ -0,0 +1,4 @@
+ config_path: outputs/training/3dgs_res_176_320_views_49/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_176_320_views_49
+
+ static_view_indices_fixed: ['0']
configs/inference/3dgs_res_352_640_views_49.yaml ADDED
@@ -0,0 +1,4 @@
+ config_path: outputs/training/3dgs_res_352_640_views_49/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_352_640_views_49
+
+ static_view_indices_fixed: ['0']
configs/inference/3dgs_res_704_1280_views_121.yaml ADDED
@@ -0,0 +1,4 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_121/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121
+
+ static_view_indices_fixed: ['0']
configs/inference/3dgs_res_704_1280_views_121_multi_6.yaml ADDED
@@ -0,0 +1,5 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6
+
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+ target_index_subsample: 4
configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml ADDED
@@ -0,0 +1,19 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic
+
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+ target_index_subsample: 4
+
+ # For dynamic scenes, set target time
+ set_manual_time_idx: true
+
+ # Only create outputs for specified target times
+ target_index_manual: [0, 60, 120]
+
+ dataset_name: lyra_dynamic_demo
+
+ # Alternative: Loop over start and number of time indices with given stride
+ # target_index_manual: null
+ # target_index_manual_stride: 1
+ # target_index_manual_start_idx: 0
+ # target_index_manual_num_idx: 121
configs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml ADDED
@@ -0,0 +1,19 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_dynamic_prune
+
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+ target_index_subsample: 4
+
+ # For dynamic scenes, set target time
+ set_manual_time_idx: true
+
+ # Only create outputs for specified target times
+ target_index_manual: [0, 60, 120]
+
+ dataset_name: lyra_dynamic_demo
+
+ # Alternative: Loop over start and number of time indices with given stride
+ # target_index_manual: null
+ # target_index_manual_stride: 1
+ # target_index_manual_start_idx: 0
+ # target_index_manual_num_idx: 121
configs/inference/3dgs_res_704_1280_views_121_multi_6_prune.yaml ADDED
@@ -0,0 +1,5 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_121_multi_6_prune/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_121_multi_6_prune
+
+ static_view_indices_fixed: ['5', '0', '1', '2', '3', '4']
+ target_index_subsample: 4
configs/inference/3dgs_res_704_1280_views_49.yaml ADDED
@@ -0,0 +1,4 @@
+ config_path: outputs/training/3dgs_res_704_1280_views_49/config.yaml
+ out_dir_inference: outputs/inference/3dgs_res_704_1280_views_49
+
+ static_view_indices_fixed: ['0']
configs/inference/default.yaml ADDED
@@ -0,0 +1,46 @@
+ config_path: outputs/training/cosmos3dgs/config.yaml
+ out_dir_inference: outputs/inference
+ dataset_name: lyra_static_demo
+
+ # If no checkpoint name is given, take the latest
+ ckpt_path: null
+ ckpt_name: null
+
+ # Set view indices manually
+ static_view_indices_fixed: null
+
+ # Only render a stride of the cameras from the 3DGS
+ target_index_subsample: 1
+
+ # Do evaluation
+ do_eval: false
+
+ # Don't read or write depth
+ use_depth: false
+
+ # Overwrite number of test images
+ num_test_images: null
+
+ # Video output fps
+ out_fps: 24
+
+ # Assume static scenes by default
+ target_index_manual: null
+ target_index_manual_start_idx: null
+
+ ## Export file config
+ # Output a grid of results; if yes, how many scenes to visualize in one grid
+ save_grid: false
+ num_grid_samples: 4
+ # Output the RGB decoder output next to the 3DGS rendering
+ save_gt_input: true
+ # Output a separate file of the RGB decoder output
+ save_video_input: false
+ # Output annotated gt depth
+ save_gt_depth: true
+ # Save an RGB-decoded version of the latents
+ save_rgb_decoding: false
+ # Output 3D gaussians as a simple ply file
+ save_gaussians: false
+ # Output 3D gaussians using the original 3DGS export script
+ save_gaussians_orig: false
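+
+ # A minimal override file (sketch, following the pattern of the other files in
+ # configs/inference/) only needs to point at a training config and an output dir:
+ #   config_path: outputs/training/<experiment>/config.yaml
+ #   out_dir_inference: outputs/inference/<experiment>
+ #   static_view_indices_fixed: ['0']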
configs/training/3dgs_res_176_320_views_17.yaml ADDED
@@ -0,0 +1,9 @@
+ output_dir: outputs/training/3dgs_res_176_320_views_17
+ img_size: [176, 320]
+ num_views: 34
+ num_input_views: 17
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ batch_size: 4
+ load_latents: False
+ max_train_steps: 10000
configs/training/3dgs_res_176_320_views_49.yaml ADDED
@@ -0,0 +1,9 @@
+ output_dir: outputs/training/3dgs_res_176_320_views_49
+ img_size: [176, 320]
+ num_views: 98
+ num_input_views: 49
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ batch_size: 4
+ load_latents: False
+ max_train_steps: 12500
configs/training/3dgs_res_352_640_views_49.yaml ADDED
@@ -0,0 +1,10 @@
+ output_dir: outputs/training/3dgs_res_352_640_views_49
+ img_size: [352, 640]
+ num_views: 98
+ num_input_views: 49
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ batch_size: 2
+ lpips_chunk_size: 8
+ load_latents: False
+ max_train_steps: 15000
configs/training/3dgs_res_704_1280_views_121.yaml ADDED
@@ -0,0 +1,10 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_121
+ img_size: [704, 1280]
+ num_views: 130
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ batch_size: 1
+ lpips_chunk_size: 1
+ checkpointing_steps: 200
+ max_train_steps: 75000
configs/training/3dgs_res_704_1280_views_121_multi_6.yaml ADDED
@@ -0,0 +1,12 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6
+ img_size: [704, 1280]
+ num_views: 130
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 6
+ batch_size: 1
+ static_view_indices_sampling: random_bucket
+ static_frame_sampling: exponential
+ lpips_chunk_size: 1
+ checkpointing_steps: 50
+ max_train_steps: 82000
configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic.yaml ADDED
@@ -0,0 +1,16 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic
+ data_mode: [['lyra_dynamic', 1]]
+ img_size: [704, 1280]
+ num_views: 133
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 6
+ batch_size: 1
+ static_view_indices_sampling: random_bucket
+ static_frame_sampling: exponential
+ lpips_chunk_size: 1
+ use_time_embedding: true
+ # use flipped supervision
+ select_target_views_input_dynamic: false
+ checkpointing_steps: 50
+ max_train_steps: 90000
configs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune.yaml ADDED
@@ -0,0 +1,19 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_dynamic_prune
+ data_mode: [['lyra_dynamic', 1]]
+ img_size: [704, 1280]
+ num_views: 133
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 6
+ batch_size: 1
+ static_view_indices_sampling: random_bucket
+ static_frame_sampling: exponential
+ lpips_chunk_size: 1
+ lambda_opacity: 0.1
+ gaussians_prune_ratio: 0.8
+ gaussians_random_ratio: 0.0
+ # use flipped supervision
+ select_target_views_input_dynamic: false
+ use_time_embedding: true
+ checkpointing_steps: 50
+ max_train_steps: 91000
configs/training/3dgs_res_704_1280_views_121_multi_6_prune.yaml ADDED
@@ -0,0 +1,15 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_121_multi_6_prune
+ img_size: [704, 1280]
+ num_views: 130
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 6
+ batch_size: 1
+ static_view_indices_sampling: random_bucket
+ static_frame_sampling: exponential
+ lpips_chunk_size: 1
+ lambda_opacity: 0.1
+ gaussians_prune_ratio: 0.8
+ gaussians_random_ratio: 0.0
+ checkpointing_steps: 50
+ max_train_steps: 83000
configs/training/3dgs_res_704_1280_views_49.yaml ADDED
@@ -0,0 +1,10 @@
+ output_dir: outputs/training/3dgs_res_704_1280_views_49
+ img_size: [704, 1280]
+ num_views: 98
+ num_input_views: 49
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ batch_size: 1
+ lpips_chunk_size: 1
+ load_latents: False
+ max_train_steps: 17500
configs/training/default.yaml ADDED
@@ -0,0 +1,160 @@
+ experiment_name: cosmos3dgs
+ output_dir: outputs/training/cosmos3dgs
+ data_mode: [['lyra_static', 1]]
+ img_size: [704, 1280]
+ resume_from_checkpoint: latest
+ resume_from_checkpoint_dir: null
+ seed: 42
+ checkpointing_steps: 100
+ permanent_checkpointing_steps: 2500
+ checkpoints_total_limit: 5
+ max_train_steps: 100000000
+ num_workers: 16
+ batch_size: 4
+ local_rank: -1
+
+ log_with: null
+ pretrained_model_name_or_path: null
+ save_multi_random_states: false
+ find_unused_parameters: false
+ use_ema: false
+ job_stop_steps: null
+ resume_pretrained_model_ckpt: null
+
+ # Main blocks
+ use_mamba: true
+ llrm_7m1t: true
+ llrm_7m1t_index: 8
+ llrm_7m1t_index_residual: 8
+ enc_depth: 16
+ enc_embed_dim: 512
+ enc_num_heads: 8
+ mlp_ratio: 4
+ patch_size: 2
+ patch_size_temporal: 1
+ num_block_channels_reduce: null
+ use_pos_embedding: false
+ gradient_checkpoint_transformer: true
+
+ # Tokenizer
+ vae_backbone: cosmos1
+ vae_path: ./checkpoints/cosmos_predict1/Cosmos-Tokenize1-CV8x8x8-720p
+
+ # Latent decoding
+ use_rgb_decoder: false
+ use_patch_embeddings_encoder: true
+ use_cosmos_decoder: false
+ transposed_conv_type: null # [None, 'factorized']
+ transposed_conv_hidden_channels: null
+ num_latent_c: 16
+ latent_time_compression: 8
+ latent_spat_compression: 8
+ patch_size_out_factor: [1, 8, 8]
+ gradient_checkpoint_conv: true
+
+ # Camera conditioning
+ use_plucker: true
+ relative_translation_scale: true
+ plucker_embedding_vae: true
+ compute_plucker_cuda: true
+ compute_plucker_dtype: bfloat16
+ plucker_embedding_vae_fuse_type: concat
+
+ # Frame sampling
+ num_views: 130
+ num_input_views: 121
+ gs_view_chunk_size: 1
+ num_input_multi_views: 1
+ fuse_multi_views: true
+ process_multi_views: true
+ static_view_indices_sampling: random
+ deferred_bp: true
+ static_frame_sampling: uniform
+ # sample a variable number of input multi views
+ sample_num_input_multi_views: True
+ static_view_indices_fixed: ['0']
+ # patch-based training
+ gs_render_patch_size: null
+
+ # subsample gaussians
+ sub_sample_gaussians_factor: null # e.g., [t, h, w] = [1, 2, 2], if null = no subsampling
+ sub_sample_gaussians: true
+ sub_sample_gaussians_type: null # [None, 'learned']
+ sub_sample_gaussians_type_tokens: 'global' # [None, 'global', 'local']
+ sub_sample_gaussians_temperature: 1.0
+
+ # freely moving gaussians
+ gaussians_predict_offset: false
+ use_gaussians_predict_offset: true
+ gaussians_predict_offset_range: [-1, 1]
+ gaussians_predict_offset_act: 'clamp'
+
+ # general rendering config
+ use_3dgut: true
+ znear: 0.1
+ zfar: 500
+ dnear: 0.1
+ dfar: 500
+ output_dims: 12
+ gaussian_scale_cap: 0.3
+ pre_sigmoid_distance_shift: -1.65
+
+ # Training setup
+ workspace: ./workspace
+ logging_dir: logs
+ resume: null
+ gradient_accumulation_steps: 1
+ gradient_clip: 1.0
+ mixed_precision: bf16
+ use_deepspeed: true
+ deepspeed_type: null
+ use_fsdp: false
+ learning_rate: 1e-4
+ scale_lr: false
+ lr_scheduler: constant_with_warmup
+ lr_warmup_steps: 100
+ lr_overwrite: false
+ use_8bit_adam: false
+ allow_tf32: true
+ adam_beta1: 0.9
+ adam_beta2: 0.999
+ adam_weight_decay: 1e-2
+ adam_epsilon: 1e-8
+ max_grad_norm: 1.0
+ autocast_cache_enabled: false
+ set_transformer_dtype: true
+ compile_frozen_modules: false
+ use_flex_attention: false
+ use_qk_norm: false
+ grad_norm_cap: 5000
+
+ # Additional losses
+ lambda_lpips: 0.5
+ lpips_img_size_min: 704
+ lpips_chunk_size: 32
+ lambda_ssim: 0.0
+ use_depth: true
+ lambda_depth: 0.05
+ lambda_opacity: 0.0
+ gaussians_prune_ratio: 0.0
+ gaussians_random_ratio: 0.0
+
+ # Dynamic
+ use_time_embedding: false
+ use_interp_target: false
+ static_time: false
+ time_embedding: true
+ time_embedding_dim: 3
+ time_embedding_vae: true
+ time_embedding_use_orig: true
+ timesteps_eps: 0.
+ select_target_views_input_dynamic: true
+
+ # Data
+ num_test_scenes: 16
+ subsample_data_train_val: true
+ mirror_static: true
+ mirror_dynamic: true
+ set_manual_time_idx: false
+ load_latents: true
+ subsample_target: null
cosmos_predict1/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/__init__.py ADDED
@@ -0,0 +1,14 @@
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
+ # SPDX-License-Identifier: Apache-2.0
+ #
+ # Licensed under the Apache License, Version 2.0 (the "License");
+ # you may not use this file except in compliance with the License.
+ # You may obtain a copy of the License at
+ #
+ # http://www.apache.org/licenses/LICENSE-2.0
+ #
+ # Unless required by applicable law or agreed to in writing, software
+ # distributed under the License is distributed on an "AS IS" BASIS,
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
+ # See the License for the specific language governing permissions and
+ # limitations under the License.
cosmos_predict1/autoregressive/callbacks/video_sampling_teacher_forcing.py ADDED
@@ -0,0 +1,352 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import glob
17
+ import math
18
+ import os
19
+ from typing import Optional
20
+
21
+ import numpy as np
22
+ import torch
23
+ import torchvision
24
+ import torchvision.transforms.functional as torchvision_F
25
+ import wandb
26
+ from einops import rearrange
27
+ from megatron.core import parallel_state
28
+ from torch.distributed import get_process_group_ranks
29
+
30
+ from cosmos_predict1.autoregressive.utils.parallel import (
31
+ broadcast_data_batch_in_tp_cp_group,
32
+ gather_batch_from_cp_ranks,
33
+ get_batch_on_this_cp_rank,
34
+ )
35
+ from cosmos_predict1.callbacks.every_n import EveryN
36
+ from cosmos_predict1.utils import distributed, log, misc
37
+ from cosmos_predict1.utils.model import Model
38
+ from cosmos_predict1.utils.trainer import Trainer
39
+
40
+
41
+ def resize_image(image: torch.Tensor, resize_factor=0.5) -> torch.Tensor:
42
+ _, _, h, w = image.shape
43
+ new_h, new_w = int(resize_factor * h), int(resize_factor * w)
44
+ return torchvision_F.resize(image, (new_h, new_w))
45
+
46
+
47
+ class VideoSamplingTeacherForcing(EveryN):
48
+ def __init__(
49
+ self,
50
+ every_n: int,
51
+ step_size: int = 1,
52
+ video_latent_shape: list = [6, 24, 40],
53
+ num_frames_to_display: int = 4,
54
+ save_folder: Optional[str] = None,
55
+ num_file_to_log: int = 8,
56
+ ):
57
+ r"""
58
+ This callback enables us to perform teacher forcing inference on the training data.
59
+ By teacher forcing, we mean providing ground truth video tokens as inputs, and simply asking the model
60
+ to predict the next tokens. The predicted next tokens are then visualized. This does not perform
61
+ autoregressive sampling.
62
+ We also upload the downsampled video frames to wandb. Downsampling is needed for wandb to work fast.
63
+
64
+ Args:
65
+ every_n (int): Call this callback every_n steps
66
+ step_size (int): Number of steps taken for gradient accumulation. Global iteration number is
67
+ iteration // self.step_size
68
+ video_latent_shape (list): Shape of the video latent
69
+ num_frames_to_display (int): Number of frames to subsample for displaying in wandb
70
+ save_folder (str): Name of the local folder to save the video
71
+ num_file_to_log (int): Number of files to upload to wandb
72
+ """
73
+ super().__init__(every_n, step_size)
74
+ self.save_folder = save_folder if save_folder else self.__class__.__name__
75
+ self.video_latent_shape = video_latent_shape
76
+ self.num_frames_to_display = num_frames_to_display
77
+ self.num_file_to_log = num_file_to_log
78
+ self.rank = distributed.get_rank()
79
+
80
+ def on_train_start(self, model: Model, iteration: int = 0) -> None:
81
+ config_job = self.config.job
82
+ self.local_dir = f"{config_job.path_local}/{self.save_folder}"
83
+ if self.rank == 0:
84
+ os.makedirs(self.local_dir, exist_ok=True)
85
+ log.info(f"Video Teacher-Forcing Callback: local_dir: {self.local_dir}")
86
+
87
+ @torch.inference_mode()
88
+ def every_n_impl(
89
+ self,
90
+ trainer: Trainer,
91
+ model: Model,
92
+ data_batch: dict[str, torch.Tensor],
93
+ output_batch: dict[str, torch.Tensor],
94
+ loss: torch.Tensor,
95
+ iteration: int,
96
+ ) -> None:
97
+ # Tokenize the data
98
+
99
+ broadcast_data_batch_in_tp_cp_group(data_batch)
100
+
101
+ input_vid = data_batch[model.tokenizer.tokenizer_config.video_tokenizer.data_key]
102
+
103
+ dataset_name = data_batch.get("dataset_name", None)
104
+ if dataset_name is not None and dataset_name.startswith("image"):
105
+ # we disable the callback if the input video is an image batch
106
+ log.info(f"dataset_name is {dataset_name}, skip this callback")
107
+ return
108
+
109
+ # get the caption
110
+ captions = data_batch.get("caption", None)
111
+
112
+ # get the context embedding and mask
113
+ context = data_batch.get("context", None)
114
+ context_mask = data_batch.get("context_mask", None)
115
+ if context is not None:
116
+ context = misc.to(context, "cuda").detach().clone()
117
+ if context_mask is not None:
118
+ context_mask = misc.to(context_mask, "cuda").detach().clone()
119
+ # get the action
120
+ action = data_batch.get("action", None)
121
+ if action is not None:
122
+ action = misc.to(action, "cuda").detach().clone()
123
+
124
+ # Input tokens
125
+ tokens, _ = model.tokenizer.tokenize(data_batch)
126
+ tokens = misc.to(tokens, "cuda").detach().clone()
127
+ skip_save_file = False
128
+ if parallel_state.get_context_parallel_world_size() > 1:
129
+ cp_group = parallel_state.get_context_parallel_group()
130
+ if self.rank != min(get_process_group_ranks(cp_group)):
131
+ skip_save_file = True
132
+ tokens = get_batch_on_this_cp_rank(tokens)
133
+ if parallel_state.get_tensor_model_parallel_world_size() > 1:
134
+ # Turn on TP
135
+ tp_group = parallel_state.get_tensor_model_parallel_group()
136
+ if self.rank != min(get_process_group_ranks(tp_group)):
137
+ skip_save_file = True
138
+ tokens_encoded_in_train = output_batch["encode_tokens"].detach()
139
+ percent_token_diff = (tokens != tokens_encoded_in_train).float().mean()
140
+ percent_token_diff = distributed.dist_reduce_tensor(percent_token_diff)
141
+
142
+ input_tokens = tokens
143
+
144
+ num_tokens_to_generate = np.prod(self.video_latent_shape)
145
+
146
+ # Do a forward pass
147
+ logits = model.model.forward(
148
+ tokens,
149
+ input_pos=None,
150
+ context=context,
151
+ context_mask=context_mask,
152
+ action=action,
153
+ )
154
+ if parallel_state.get_context_parallel_world_size() > 1:
155
+ logits = gather_batch_from_cp_ranks(logits)
156
+ input_tokens = gather_batch_from_cp_ranks(input_tokens)
157
+
158
+ # Start position for video tokens in the vocabulary
159
+ video_token_start = self.config.model.tokenizer_config.video_tokenizer.tokenizer_offset
160
+ video_vocab_size = self.config.model.tokenizer_config.video_tokenizer.vocab_size
161
+
162
+ # Clipping logits only to video tokens. We remove the text vocab predictions.
163
+ # This will ensure that the video tokens only correspond to the video part of the vocabulary.
164
+ logits = logits[:, :, video_token_start : video_token_start + video_vocab_size]
165
+
166
+ # Sample the argmax token; this is sufficient for the teacher-forcing experiment.
167
+ logits = logits.contiguous()
168
+ generations = torch.argmax(logits, dim=-1)
169
+
170
+ # For each video in the batch, subsample frames for display
171
+ batch_size = input_tokens.shape[0]
172
+ out_frames = []
173
+ out_videos_gen = []
174
+ out_videos_rec = []
175
+ out_videos_gt = []
176
+ # log the accuracy of teacher-forcing
177
+ acc = []
178
+ loss_list = []
179
+
180
+ for sample_num in range(batch_size):
181
+ # Subsample the generations to the video part.
182
+ # This corresponds to the part from begin of video to end of video.
183
+ bov_token = model.tokenizer.video_special_tokens["<|begin_of_video|>"]
184
+ bov_index = input_tokens[sample_num] == bov_token
185
+ use_special_token = sum(bov_index) != 0
186
+ if use_special_token:
187
+ bov_index = bov_index.nonzero().item()
188
+ # generations: <bov> real_token1 real_token2, ... real_token7680; total 7680
189
+ # gen_video_tokens: real_token1 real_token2, ..., real_token7680; total 7680
190
+ # for vis: real_token1 real_token2, ..., real_token7680; total 7680
191
+ # for accuracy: real_token1 real_token2, ..., real_token7680; total 7680
192
+ gen_video_tokens = generations[sample_num][bov_index : bov_index + num_tokens_to_generate]
193
+ gen_video_tokens_vis = gen_video_tokens
194
+ gen_video_tokens_acc = gen_video_tokens
195
+ logits_loss = logits[sample_num][bov_index : bov_index + num_tokens_to_generate]
196
+ else:
197
+ # generations: real_token1 real_token2, ... real_token7680
198
+ # gen_video_tokens: real_token2 real_token3, ..., real_token7680; total 7679
199
+ # We need different tokens for vis and accuracy compute
200
+ # for acc: real_token2 real_token3, ..., real_token7680; total 7679
201
+ # for vis: pad_token (real_token2, ..., real_token7680); total 1 + 7679
202
+ gen_video_tokens = generations[sample_num][
203
+ : num_tokens_to_generate - 1
204
+ ] # remove the last token since there is no gt
205
+ # Since the first token is not predicted, we need to add the gt first token to make sure the shape is correct
206
+ gen_video_tokens_vis = torch.cat([input_tokens[sample_num][0:1], gen_video_tokens])
207
+ gen_video_tokens_acc = gen_video_tokens
208
+ logits_loss = logits[sample_num][: num_tokens_to_generate - 1]
209
+
210
+ # Rearrange the video to a spatial tensor
211
+ gen_video_tokens_vis_BTHW = rearrange(
212
+ gen_video_tokens_vis.unsqueeze(0),
213
+ "B (T H W) -> B T H W",
214
+ T=self.video_latent_shape[0],
215
+ H=self.video_latent_shape[1],
216
+ W=self.video_latent_shape[2],
217
+ )
218
+
219
+ # for real videos, we need to skip the bov and eov tokens for decoding
220
+ if use_special_token:
221
+ # input_tokens: <bov> real_token1 real_token2 ... <eov> <eov> ...
222
+ # real_video_tokens: real_token1 real_token2 ... real_token7680; total 7680
223
+ # for vis: real_token1 real_token2 ... real_token7680; total 7680
224
+ # for accuracy: real_token1 real_token2 ... real_token7680; total 7680; we include real_token1 since the output prediction also includes it, see gen_video_tokens_acc above
225
+ real_video_tokens = (
226
+ input_tokens[sample_num][bov_index + 1 : bov_index + num_tokens_to_generate + 1] - video_token_start
227
+ )
228
+ real_video_tokens_vis = real_video_tokens
229
+ real_video_tokens_acc = real_video_tokens
230
+ else:
231
+ # input_tokens: real_token1 real_token2,... real_token7680; total 7680
232
+ # real_video_tokens: real_token1 real_token2,... real_token7680; total 7680
233
+ # for acc: gt start from real_token2, real_token3; total 7679, remove the first token since it is not predicted
234
+ # for vis: gt start from real_token1, real_token2; total 7680
235
+ real_video_tokens = input_tokens[sample_num][:num_tokens_to_generate] - video_token_start
236
+ real_video_tokens_vis = real_video_tokens
237
+ real_video_tokens_acc = real_video_tokens[1:].flatten()
238
+
239
+ real_video_tokens_vis_BTHW = rearrange(
240
+ real_video_tokens_vis.unsqueeze(0),
241
+ "B (T H W) -> B T H W",
242
+ T=self.video_latent_shape[0],
243
+ H=self.video_latent_shape[1],
244
+ W=self.video_latent_shape[2],
245
+ )
246
+ # Calculate accuracy
247
+ correct_predictions = (gen_video_tokens_acc == real_video_tokens_acc).float()
248
+ labels = real_video_tokens_acc.clone()
249
+
250
+ if model.config.ignore_first_num_tokens > 0:
251
+ labels[: model.config.ignore_first_num_tokens] = model.tokenizer.ignore_index
252
+ select_index = labels != model.tokenizer.ignore_index
253
+ correct_predictions = correct_predictions[select_index]
254
+
255
+ loss = torch.nn.functional.cross_entropy(
256
+ logits_loss, labels, ignore_index=model.tokenizer.ignore_index, reduction="none"
257
+ )
258
+ acc.append(correct_predictions.mean() * 100.0)
259
+ loss_list.append(loss.mean())
260
+
261
+ # Decode the predicted latents
262
+ if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
263
+ vid_decoded = model.tokenizer.video_tokenizer.decode(gen_video_tokens_vis_BTHW.cuda())
264
+ else:
265
+ vid_decoded = model.tokenizer.video_tokenizer.decode_with_overlap(
266
+ gen_video_tokens_vis_BTHW.cuda(),
267
+ temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
268
+ )
269
+ # normalize decoded images from [-1, 1] to [0, 1], and clip value
270
+ vid_decoded = (vid_decoded * 0.5 + 0.5).clamp_(0, 1)
271
+ vid_decoded = vid_decoded[0]
272
+
273
+ # Decode the GT latents
274
+ if model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap == 0:
275
+ vid_rec = model.tokenizer.video_tokenizer.decode(real_video_tokens_vis_BTHW.cuda())
276
+ else:
277
+ vid_rec = model.tokenizer.video_tokenizer.decode_with_overlap(
278
+ real_video_tokens_vis_BTHW.cuda(),
279
+ temporal_overlap=model.tokenizer.tokenizer_config.video_tokenizer.temporal_overlap,
280
+ )
281
+ # normalize decoded image from [-1, 1] to [0, 1], and clip value
282
+ vid_rec = (vid_rec * 0.5 + 0.5).clamp_(0, 1)
283
+ vid_rec = vid_rec[0]
284
+
285
+ vid_input = input_vid[sample_num] # [-1, 1], input_vid shape: [B, C, L, H, W]
286
+ vid_input = (vid_input * 0.5 + 0.5).clamp_(0, 1).cuda() # Convert to [0, 1], [C, L, H, W]
287
+
288
+ # Subsample real and generated video frames
289
+ input_video_frames = vid_input.transpose(0, 1) # [L, C, H, W]
290
+ rec_video_frames = vid_rec.transpose(0, 1)
291
+ gen_video_frames = vid_decoded.transpose(0, 1)
292
+ out_videos_gen.append(gen_video_frames)
293
+ out_videos_rec.append(rec_video_frames)
294
+ out_videos_gt.append(input_video_frames)
295
+
296
+ stride = math.ceil(rec_video_frames.shape[0] / self.num_frames_to_display)
297
+
298
+ input_video_frames_subsampled = resize_image(input_video_frames[0::stride], resize_factor=0.5)
299
+ input_video_frames_subsampled = torchvision.utils.make_grid(
300
+ input_video_frames_subsampled, nrow=input_video_frames_subsampled.shape[0]
301
+ )
302
+
303
+ gt_video_frames_subsampled = resize_image(rec_video_frames[0::stride], resize_factor=0.5)
304
+ gt_video_frames_subsampled = torchvision.utils.make_grid(
305
+ gt_video_frames_subsampled, nrow=gt_video_frames_subsampled.shape[0]
306
+ )
307
+ gen_video_frames_subsampled = resize_image(gen_video_frames[0::stride], resize_factor=0.5)
308
+ gen_video_frames_subsampled = torchvision.utils.make_grid(
309
+ gen_video_frames_subsampled, nrow=gen_video_frames_subsampled.shape[0]
310
+ )
311
+
312
+ out_frames.append(input_video_frames_subsampled)
313
+ out_frames.append(gt_video_frames_subsampled)
314
+ out_frames.append(gen_video_frames_subsampled)
315
+
316
+ scaled_num_rank_to_log = (
317
+ self.num_file_to_log
318
+ * parallel_state.get_context_parallel_world_size()
319
+ * parallel_state.get_tensor_model_parallel_world_size()
320
+ )
321
+ if self.rank < scaled_num_rank_to_log and not skip_save_file:
322
+ local_path = f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_{self.rank:04d}.jpg"
323
+ out_image_grid = torchvision.utils.make_grid(out_frames, nrow=1, padding=0, normalize=False)
324
+ os.makedirs(os.path.dirname(local_path), exist_ok=True)
325
+ torchvision.utils.save_image(out_image_grid, local_path)
326
+
327
+ # Log to wandb
328
+ avg_acc = distributed.dist_reduce_tensor(torch.stack(acc).mean()).item()
329
+ avg_loss = distributed.dist_reduce_tensor(torch.stack(loss_list).mean()).item()
330
+ log_info = ""
331
+ if "acc" in output_batch:
332
+ log_info = f"train acc: {(output_batch['acc'].mean().item()):.6f}%"
333
+ if percent_token_diff is not None:
334
+ log_info += f"; percent_token_diff_train_val: {percent_token_diff.item() * 100:.6f}%"
335
+ log.info(
336
+ f"Eval iteration {iteration} teacher-forcing accuracy: {avg_acc:.6f}%, loss: {avg_loss:.4f}; {log_info}"
337
+ )
338
+ if self.rank == 0 and wandb.run:
339
+ local_files = glob.glob(f"{self.local_dir}/vid_teacher_forcing_iter_{iteration:09d}_*.jpg")
340
+ local_files = sorted(local_files)[: self.num_file_to_log]
341
+ if captions is None:
342
+ captions = ["vid_frames_teacher_forcing"] * len(local_files)
343
+ for local_path, caption in zip(local_files, captions):
344
+ wandb.log(
345
+ {"frames": [wandb.Image(local_path, caption=caption)]},
346
+ step=iteration,
347
+ )
348
+
349
+ wandb.log({"eval/teacher_forcing_acc": avg_acc}, step=iteration)
350
+ wandb.log({"eval/teacher_forcing_loss": avg_loss}, step=iteration)
351
+ if percent_token_diff is not None:
352
+ wandb.log({"eval/percent_token_diff_train_val": percent_token_diff.item() * 100}, step=iteration)
cosmos_predict1/autoregressive/configs/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/__init__.py ADDED
@@ -0,0 +1,14 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
cosmos_predict1/autoregressive/configs/base/callbacks.py ADDED
@@ -0,0 +1,33 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from cosmos_predict1.autoregressive.callbacks.video_sampling_teacher_forcing import VideoSamplingTeacherForcing
17
+ from cosmos_predict1.callbacks.grad_clip import GradClip
18
+ from cosmos_predict1.utils.callback import ProgressBarCallback
19
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
20
+
21
+ BASIC_CALLBACKS = dict(
22
+ progress_bar=L(ProgressBarCallback)(),
23
+ grad_clip=L(GradClip)(clip_norm=1.0, fsdp_enabled="${model.model_config.fsdp_enabled}", model_key="model"),
24
+ )
25
+
26
+ VIDEO_TEACHER_FORCING_CALLBACK = dict(
27
+ vid_sampling_tf=L(VideoSamplingTeacherForcing)(
28
+ every_n=500,
29
+ video_latent_shape="${model.model_config.video_latent_shape}",
30
+ num_frames_to_display=4,
31
+ save_folder="video_sampling_teacher_forcing",
32
+ )
33
+ )
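These are plain dictionaries of lazily-constructed callbacks; a likely usage (a sketch, since the actual wiring lives in the experiment configs) is to merge them into a single mapping that the trainer consumes:

```python
from cosmos_predict1.autoregressive.configs.base.callbacks import (
    BASIC_CALLBACKS,
    VIDEO_TEACHER_FORCING_CALLBACK,
)

# Combine the always-on callbacks with the periodic teacher-forcing visualization;
# the interpolated "${model.model_config...}" fields are resolved when the full
# config tree is instantiated by the training framework.
callbacks = {**BASIC_CALLBACKS, **VIDEO_TEACHER_FORCING_CALLBACK}
```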
cosmos_predict1/autoregressive/configs/base/dataloader.py ADDED
@@ -0,0 +1,72 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from megatron.core import parallel_state
17
+ from torch.utils.data import DataLoader, DistributedSampler
18
+
19
+ from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
20
+ from cosmos_predict1.autoregressive.datasets.video_dataset import VideoDataset
21
+ from cosmos_predict1.utils import log
22
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
23
+
24
+ DATALOADER_OPTIONS = {}
25
+
26
+
27
+ def get_sampler(dataset):
28
+ return DistributedSampler(
29
+ dataset,
30
+ num_replicas=parallel_state.get_data_parallel_world_size(),
31
+ rank=parallel_state.get_data_parallel_rank(),
32
+ shuffle=True,
33
+ seed=0,
34
+ )
35
+
36
+
37
+ def dataloader_register(key):
38
+ log.info(f"registering dataloader {key}...")
39
+
40
+ def decorator(func):
41
+ DATALOADER_OPTIONS[key] = func
42
+ return func
43
+
44
+ return decorator
45
+
46
+
47
+ @dataloader_register("tealrobot_video")
48
+ def get_tealrobot_video(
49
+ batch_size: int = 1,
50
+ dataset_dir: str = "datasets/cosmos_nemo_assets/videos/",
51
+ sequence_interval: int = 1,
52
+ num_frames: int = 33,
53
+ video_size: list[int, int] = [640, 848],
54
+ start_frame_interval: int = 1,
55
+ ):
56
+ dataset = L(VideoDataset)(
57
+ config=VideoDatasetConfig(
58
+ dataset_dir=dataset_dir,
59
+ sequence_interval=sequence_interval,
60
+ num_frames=num_frames,
61
+ video_size=video_size,
62
+ start_frame_interval=start_frame_interval,
63
+ )
64
+ )
65
+ return L(DataLoader)(
66
+ dataset=dataset,
67
+ sampler=L(get_sampler)(dataset=dataset),
68
+ batch_size=batch_size,
69
+ drop_last=True,
70
+ pin_memory=True,
71
+ num_workers=8,
72
+ )
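The decorator above populates `DATALOADER_OPTIONS`, so a downstream config can look a dataloader up by name and override its arguments. Note that the returned object is still a lazy `L(DataLoader)` config, not a live dataloader; a usage sketch with illustrative overrides:

```python
from cosmos_predict1.autoregressive.configs.base.dataloader import DATALOADER_OPTIONS

# Fetch the registered builder by key and override a few arguments.
# Instantiation (including the DistributedSampler, which needs an initialized
# megatron parallel_state) happens later, inside the training framework.
dataloader_cfg = DATALOADER_OPTIONS["tealrobot_video"](
    batch_size=2,
    num_frames=33,
    video_size=[640, 848],
)
```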
cosmos_predict1/autoregressive/configs/base/dataset.py ADDED
@@ -0,0 +1,39 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Dataset config class."""
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.utils.config import make_freezable
21
+
22
+
23
+ @make_freezable
24
+ @attrs.define(slots=False)
25
+ class VideoDatasetConfig:
26
+ """
27
+ Args:
28
+ dataset_dir (str): Base path to the dataset directory
29
+ sequence_interval (int): Interval between sampled frames in a sequence
30
+ num_frames (int): Number of frames to load per sequence
31
+ video_size (list): Target size [H,W] for video frames
32
+ start_frame_interval (int): Interval between starting frames of sequences
33
+ """
34
+
35
+ dataset_dir: str = "datasets/cosmos_nemo_assets/videos/"
36
+ sequence_interval: int = 1
37
+ num_frames: int = 33
38
+ video_size: list[int, int] = [640, 848]
39
+ start_frame_interval: int = 1
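Since this is a plain attrs class, it can be constructed directly with only the fields that differ from the defaults documented above (the values below are illustrative, not tied to any particular experiment):

```python
from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig

# Override the dataset location and clip geometry; everything else keeps its default.
cfg = VideoDatasetConfig(
    dataset_dir="datasets/cosmos_nemo_assets/videos/",
    num_frames=49,
    video_size=[704, 1280],
)
print(cfg.num_frames, cfg.video_size)
```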
cosmos_predict1/autoregressive/configs/base/model.py ADDED
@@ -0,0 +1,318 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import TokenizerConfig
21
+ from cosmos_predict1.utils import config
22
+
23
+ _ACTION_DIM = 8
24
+ from cosmos_predict1.utils.lazy_config import LazyDict
25
+
26
+
27
+ @attrs.define
28
+ class ModelConfig:
29
+ """
30
+ A class to hold model configuration arguments.
31
+
32
+ Args:
33
+ dim (int): The dimensionality of the input and output of each transformer block.
34
+ n_layers (int): Number of layers in the transformer.
35
+ n_heads (int): Number of attention heads.
36
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
37
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
38
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
39
+ vocab_size (int): Vocabulary size.
40
+ ffn_hidden_size (int): Hidden size for feedforward network.
41
+ norm_eps (float): Epsilon value for normalization.
42
+ rope_theta (float): Theta value for rotary positional embeddings.
43
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
44
+ max_batch_size (int): Maximum batch size for inference.
45
+ max_seq_len (int): Maximum sequence length for input text.
46
+ fuse_qkv (bool): Whether to fuse QKV in attention. Defaults to True.
47
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
48
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
49
+ precision (str): Data type for the model.
50
+ use_qk_normalization (bool): Whether to enable QK normalization.
51
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
52
+ ckpt_dir (str): Checkpoint directory.
53
+ ckpt_path (str): Checkpoint path.
54
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
55
+ yarn_scale (Optional[float]): Scale factor for YaRN.
56
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
57
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
58
+ original_seq_len (Optional[int]): Original sequence length.
59
+ vision_encoder (Optional[str]): Vision encoder name.
60
+ mm_projector (Optional[str]): Multi-modal projector name.
61
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
62
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "3D".
63
+ pytorch_rope_version (Optional[str]): Version of the PyTorch RoPE implementation. Choices: "v1", "v2".
64
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
65
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
66
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3.
67
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
68
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
69
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
70
+ num_video_frames (Optional[int]): Number of video frames.
71
+ video_height (Optional[int]): Raw video pixel height dimension.
72
+ video_width (Optional[int]): Raw video pixel width dimension.
73
+ video_latent_shape (Optional[list]): Video tokenizer output dimension, in (T,H,W).
74
+ """
75
+
76
+ dim: int = attrs.field(default=4096)
77
+ n_layers: int = attrs.field(default=32)
78
+ n_heads: int = attrs.field(default=32)
79
+ n_kv_heads: Optional[int] = attrs.field(default=8)
80
+ head_dim: Optional[int] = attrs.field(default=None)
81
+ vocab_size: int = attrs.field(default=128256)
82
+ ffn_hidden_size: int = attrs.field(default=14336)
83
+ norm_eps: float = attrs.field(default=1e-5)
84
+ rope_theta: float = attrs.field(default=500000)
85
+ apply_abs_pos_emb: bool = attrs.field(default=False)
86
+ max_batch_size: int = attrs.field(default=1)
87
+ max_seq_len: int = attrs.field(default=8192)
88
+ fuse_qkv: bool = attrs.field(default=False)
89
+ causal_mask: bool = attrs.field(default=True)
90
+ norm_type: str = attrs.field(default="rmsnorm")
91
+ precision: str = attrs.field(default="bfloat16")
92
+ use_qk_normalization: bool = False
93
+ tokenizer: Optional[TokenizerConfig] = None
94
+ tensor_model_parallel_size: int = attrs.field(default=1)
95
+ ckpt_dir: Optional[str] = attrs.field(default=None)
96
+ ckpt_path: Optional[str] = attrs.field(
97
+ default=None
98
+ ) # If not None, load the model from this path instead of ckpt_dir
99
+ apply_yarn: Optional[bool] = attrs.field(default=False)
100
+ yarn_scale: Optional[float] = attrs.field(default=None)
101
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
102
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
103
+ original_seq_len: Optional[int] = attrs.field(default=None)
104
+ vision_encoder: Optional[str] = attrs.field(default=None)
105
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
106
+ mm_projector: Optional[str] = attrs.field(default=None)
107
+ rope_dim: Optional[str] = attrs.field(default="1D")
108
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
109
+ original_latent_shape: Optional[list] = None
110
+ pad_to_multiple_of: Optional[int] = None
111
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
112
+ insert_cross_attn: bool = False
113
+ insert_cross_attn_every_k_layers: int = 1
114
+ context_dim: Optional[int] = attrs.field(default=1024)
115
+ # For video training
116
+ num_video_frames: Optional[int] = None
117
+ # Raw video pixel dimension
118
+ video_height: Optional[int] = None
119
+ video_width: Optional[int] = None
120
+ # Video tokenizer output dimension, in (T, H, W); computed as num_video_frames / temporal_compression_factor, video_height / spatial_compression_factor, video_width / spatial_compression_factor
121
+ video_latent_shape: Optional[list] = None
122
+
123
+ def __getitem__(self, item):
124
+ return getattr(self, item)
125
+
126
+
127
+ @attrs.define
128
+ class TrainingModelConfig:
129
+ """
130
+ A class to hold model configuration arguments.
131
+
132
+ Args:
133
+ dim (int): The dimensionality of the input and output of each transformer block.
134
+ n_layers (int): Number of layers in the transformer.
135
+ n_heads (int): Number of attention heads.
136
+ n_kv_heads (Optional[int]): Number of key-value heads. If None, defaults to n_heads. Note: this is equivalent to
137
+ `num_gqa_groups` in TransformerEngine, where GQA means Grouped Query Attention.
138
+ head_dim (Optional[int]): Dimensionality of each head. If None, defaults to dim // n_heads.
139
+ vocab_size (int): Vocabulary size.
140
+ multiple_of (int): Ensures the hidden layer size is a multiple of this value for SwiGLU activation.
141
+ ffn_dim_multiplier (Optional[float]): Multiplier for feedforward network dimension.
142
+ ffn_hidden_size (Optional[int]): Hidden size for feedforward network. If None, use ffn_dim_multiplier to compute it.
143
+ norm_eps (float): Epsilon value for normalization.
144
+ rope_theta (float): Theta value for rotary positional embeddings.
145
+ apply_abs_pos_emb (bool): Whether to apply absolute position embeddings.
146
+ max_batch_size (int): Maximum batch size for inference (determines KV cache size).
147
+ max_seq_len (int): Maximum sequence length for input text (determines KV cache size).
148
+ fuse_qkv (bool): Whether to fuse QKV in attention. Flag for the pytorch backend.
149
+ causal_mask (bool): Whether to use causal mask. Defaults to True.
150
+ flash_attn (bool): Whether to use Flash attention.
151
+ norm_type (str): Type of normalization layer. Choices: "rmsnorm", "fused_rmsnorm", "layernorm", "np_layernorm".
152
+ backend (str): Backend for the model.
153
+ precision (str): Data type for the model.
154
+ ema (config.EMAConfig): Configuration for exponential moving average.
155
+ embedding_dropout(float): Dropout rate for the embedding layer.
156
+ attention_dropout(float): Dropout rate for attention.
157
+ hidden_dropout(float): Dropout after the attention and feed-forward layers (following TransformerEngine's
158
+ implementation in its TransformerLayer class).
159
+ use_qk_normalization (bool): Whether to enable QK normalization.
160
+ inference (bool): Whether the model is used for inference.
161
+ act_ckpt_enabled (bool): Whether to enable activation checkpointing.
162
+ fsdp_enabled (bool): Whether to enable FSDP.
163
+ fsdp (LazyDict): Configuration for FSDP.
164
+ ckpt_dir (str): Checkpoint directory.
165
+ ckpt_path (str): Checkpoint path.
166
+ cache_dir (str): Cache directory.
167
+ apply_yarn (Optional[bool]): Whether to apply YaRN (long-context extension).
168
+ yarn_scale (Optional[float]): Scale factor for YaRN.
169
+ yarn_beta_fast (Optional[int]): Beta fast variable for YaRN (i.e., low_freq_factor in Llama 3.1 RoPE scaling code)
170
+ yarn_beta_slow (Optional[int]): Beta slow variable for YaRN (i.e., high_freq_factor in Llama 3.1 RoPE scaling code)
171
+ original_seq_len (Optional[int]): Original sequence length.
172
+ depth_init (bool): If `True`, then each transformer block init uses its layer ID, and if `False`, each uses the
173
+ total number of transformer blocks. Defaults to `True` (following the TorchTitan implementation of Llama3).
174
+ context_parallel_size (int): Context parallel size. Defaults to 1.
175
+ tensor_model_parallel_size (int): Tensor model parallel size. Defaults to 1.
176
+ sequence_parallel (bool): Whether to use sequence parallelism. Defaults to False.
177
+ set_parallel_mode (bool): It is a boolean flag used by TransformerEngine to handle Tensor Parallelism.
178
+ Essentially, it is equivalent to `tensor_model_parallel_size > 1`. Defaults to `False`.
179
+ attention_tp (bool): Whether to use tensor parallelism for attention layers.
180
+ mm_projector (Optional[str]): Multimodal projector used for vision-language modeling. Defaults to None.
181
+ Choices: "identity", "linear", "mlp", "mlp_downsample".
182
+ video_latent_shape (Optional[list]): Shape of the video latent tensor. [T, H, W]
183
+ image_latent_shape (Optional[list]): Shape of the image latent tensor. [H, W]
184
+ num_video_frames (Optional[int]): Number of video frames.
185
+ rope_dim (Optional[str]): Dimensionality of the RoPE. Choices: "1D", "2D", "3D".
186
+ pytorch_rope_version (Optional[str]): Version of the RoPE for the `pytorch` backend. "v1" is the Llama implementation, and "v2" is HuggingFace/TransformerEngine implementation.
187
+ original_latent_shape (Optional[list]): Original shape of the latent tensor needed for rope extension.
188
+ pad_to_multiple_of (Optional[int]): Pad the position embedding to a multiple of this value.
189
+ peft_last_n_layers (Optional[int]): Number of last few layers to fine-tune in Parameter Efficient Fine-Tuning (PEFT). When this and peft_every_n_layers are both 0, it means all layers are fine-tuned (FFT).
190
+ peft_every_n_layers (Optional[int]): In Parameter Efficient Fine-Tuning (PEFT), every n layers are unfrozen and can be trained (in flamingo style). When this and peft_last_n_layers are both 0,
191
+ it means all layers are fine-tuned (FFT). For example, for a 40 layer model, n=8 means training layers 7, 15, 23, 31, 39, which includes the final layer.
192
+ It is advised to pick n such that the final layer is included.
193
+ freeze_vision_encoder (bool): Whether to freeze the vision encoder in vision-language model training. Defaults to False.
194
+ vision_encoder_in_channels (Optional[int]): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4-channel images with the last channel as the alpha channel, set this to 4.
195
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
196
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
197
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
198
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
199
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
200
+ use_action_condition (bool): Whether to use the robot action condition.
201
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
202
+ action_dim (Optional[int]): The dimensionality of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
203
+ action_embedding_dim (Optional[int]): The dimensionality of the robot action embedding.
204
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
205
+ sync_1d_parameters (bool): Whether to synchronize layernorm parameters (1D) across tensor parallel ranks (default True).
206
+ Note: this is to ensure all TP-ranks have the same layernorm parameters.
207
+ z_loss_coeff (float): The coefficient for the z-loss.
208
+ insert_medusa_head (bool): Whether to insert the Medusa head.
209
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
210
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
211
+ "head": fine-tune medusa heads;
212
+ "head_out": fine-tune medusa heads, and the output layer;
213
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
214
+ medusa_num_heads (int): Number of heads in the Medusa head.
215
+ medusa_num_layers (int): Number of layers in the Medusa head.
216
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
217
+ zero_init_cross_attn_proj (bool): Whether to initialize the cross-attn proj layer with zeros (default False).
218
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
219
+ """
220
+
221
+ dim: int = attrs.field(default=4096)
222
+ n_layers: int = attrs.field(default=32)
223
+ n_heads: int = attrs.field(default=32)
224
+ n_kv_heads: Optional[int] = attrs.field(default=8)
225
+ head_dim: Optional[int] = attrs.field(default=None)
226
+ vocab_size: int = attrs.field(default=128256)
227
+ multiple_of: int = attrs.field(default=1024) # make SwiGLU hidden layer size multiple of large power of 2
228
+ ffn_dim_multiplier: Optional[float] = attrs.field(default=1.3)
229
+ ffn_hidden_size: Optional[int] = attrs.field(default=None)
230
+ norm_eps: float = attrs.field(default=1e-5)
231
+ rope_theta: float = attrs.field(default=500000)
232
+ apply_abs_pos_emb: bool = attrs.field(default=False)
233
+ max_batch_size: int = attrs.field(default=1)
234
+ max_seq_len: int = attrs.field(default=8192)
235
+ fuse_qkv: bool = attrs.field(default=False)
236
+ causal_mask: bool = attrs.field(default=True)
237
+ flash_attn: bool = attrs.field(default=True)
238
+ norm_type: str = attrs.field(default="rmsnorm")
239
+ backend: str = attrs.field(default="pytorch")
240
+ precision: str = attrs.field(default="bfloat16")
241
+ ema: config.EMAConfig = config.EMAConfig(enabled=False)
242
+ embedding_dropout: float = 0.0
243
+ attention_dropout: float = 0.0
244
+ hidden_dropout: float = 0.0
245
+ use_qk_normalization: bool = False
246
+ tokenizer: Optional[TokenizerConfig] = None
247
+ inference: bool = False
248
+ act_ckpt_enabled: bool = False
249
+ fsdp_enabled: bool = False
250
+ context_parallel_size: int = attrs.field(default=1)
251
+ tensor_model_parallel_size: int = attrs.field(default=1)
252
+ sequence_parallel: bool = attrs.field(default=False)
253
+ set_parallel_mode: bool = attrs.field(default=False)
254
+ fsdp: LazyDict = LazyDict(
255
+ dict(
256
+ policy="auto", # choices: ["size", "auto"]
257
+ min_num_params=1024, # Used as policy == "size"
258
+ sharding_strategy="hybrid", # Choices: ["full", "hybrid"]. "full" means sharding_group_size = world_size
259
+ sharding_group_size=8, # If None, defaults to min(world_size, 8). Recommends 8 for training on 8-GPU nodes.
260
+ )
261
+ )
262
+ ckpt_dir: Optional[str] = attrs.field(default="")
263
+ ckpt_path: Optional[str] = attrs.field(
264
+ default=None
265
+ ) # If not None, load the model from this path instead of ckpt_dir
266
+ cache_dir: Optional[str] = attrs.field(default="/project/cosmos/ar/cache")
267
+ apply_yarn: Optional[bool] = attrs.field(default=False)
268
+ yarn_scale: Optional[float] = attrs.field(default=None)
269
+ yarn_beta_fast: Optional[int] = attrs.field(default=None)
270
+ yarn_beta_slow: Optional[int] = attrs.field(default=None)
271
+ original_seq_len: Optional[int] = attrs.field(default=None)
272
+ depth_init: bool = attrs.field(default=True)
273
+ ignore_first_num_tokens: int = 0
274
+ z_loss_coeff: float = 1e-4
275
+ attention_tp: bool = False
276
+ vision_encoder: Optional[str] = attrs.field(default=None)
277
+ mm_projector: Optional[str] = attrs.field(default=None)
278
+ rope_dim: Optional[str] = attrs.field(default="1D")
279
+ pytorch_rope_version: Optional[str] = attrs.field(default="v2")
280
+ original_latent_shape: Optional[list] = None
281
+ pad_to_multiple_of: Optional[int] = None
282
+ peft_last_n_layers: Optional[int] = attrs.field(default=0)
283
+ peft_every_n_layers: Optional[int] = attrs.field(default=0)
284
+ freeze_vision_encoder: bool = False
285
+ vision_encoder_in_channels: Optional[int] = attrs.field(default=3)
286
+ insert_cross_attn: bool = False
287
+ insert_cross_attn_every_k_layers: int = 1
288
+ context_dim: Optional[int] = attrs.field(default=1024)
289
+ finetune_layers_with_cross_attn: bool = False
290
+ finetune_layers_without_cross_attn: bool = False
291
+ use_action_condition: bool = False
292
+ action_embedding_mode: Optional[str] = attrs.field(default="mlp")
293
+ action_dim: Optional[int] = attrs.field(default=_ACTION_DIM)
294
+ action_embedding_dim: Optional[int] = attrs.field(default=1024)
295
+ group_causal_mask_mode: Optional[str] = attrs.field(default=None)
296
+ sync_1d_parameters: bool = True
297
+ # hyper-parameters for the medusa head configs
298
+ insert_medusa_head: bool = False
299
+ ft_medusa_option: str = "fft"
300
+ medusa_num_heads: int = 7
301
+ medusa_num_layers: int = 1
302
+ medusa_concat_heads: bool = True
303
+ # For video training
304
+ num_video_frames: Optional[int] = None
305
+ # Raw video pixel dimension
306
+ video_height: Optional[int] = None
307
+ video_width: Optional[int] = None
308
+ # Video tokenizer output dimension, in (T, H, W); computed as num_video_frames / temporal_compression_factor, video_height / spatial_compression_factor, video_width / spatial_compression_factor
309
+ video_latent_shape: Optional[list] = None
310
+ # For image training
311
+ image_latent_shape: Optional[list] = None
312
+ # For robot training (action)
313
+ zero_init_cross_attn_proj: bool = False
314
+ # For robot training (action)
315
+ concat_action_to_context: bool = False
316
+
317
+ def __getitem__(self, item):
318
+ return getattr(self, item)
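To make the `video_latent_shape` comment concrete with the defaults used later in `create_video2world_model_config` (compression ratio [8, 16, 16], 9-frame pixel chunks, 36 input frames, 384x640 video): each chunk maps to (9 - 1) / 8 + 1 = 2 latent frames, so the latent shape is [36 / 9 * 2, 384 / 16, 640 / 16] = [8, 24, 40], i.e. 7680 video tokens, the count that appears in the teacher-forcing comments earlier. A minimal construction sketch with these illustrative values:

```python
from cosmos_predict1.autoregressive.configs.base.model import ModelConfig

# Illustrative values only; real configs are built by the factory functions below.
config = ModelConfig(
    dim=4096,
    n_layers=16,
    n_heads=32,
    vocab_size=64000,
    rope_dim="3D",
    num_video_frames=36,
    video_height=384,
    video_width=640,
    video_latent_shape=[8, 24, 40],  # 36 frames -> 8 latent frames; 384/16 x 640/16 spatially
)
# __getitem__ lets the config be read like a dict as well as via attributes.
assert config["video_latent_shape"] == config.video_latent_shape
```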
cosmos_predict1/autoregressive/configs/base/model_config.py ADDED
@@ -0,0 +1,718 @@
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from typing import Callable, List, Optional
18
+
19
+ import torch
20
+ from megatron.core import ModelParallelConfig
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TrainingModelConfig
23
+ from cosmos_predict1.autoregressive.configs.base.tokenizer import (
24
+ TextTokenizerConfig,
25
+ TokenizerConfig,
26
+ VideoTokenizerConfig,
27
+ create_discrete_video_fsq_tokenizer_state_dict_config,
28
+ )
29
+ from cosmos_predict1.autoregressive.tokenizer.image_text_tokenizer import ImageTextTokenizer
30
+ from cosmos_predict1.autoregressive.tokenizer.text_tokenizer import TextTokenizer
31
+ from cosmos_predict1.autoregressive.training.model import AutoRegressiveTrainingModel
32
+ from cosmos_predict1.utils import log
33
+ from cosmos_predict1.utils.config import EMAConfig
34
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
35
+
36
+ # Common architecture specifications
37
+ BASE_CONFIG = {"n_kv_heads": 8, "norm_type": "rmsnorm", "norm_eps": 1e-5, "ffn_hidden_size": 14336}
38
+ COSMOS_ARCHITECTURES = {
39
+ "1b": {
40
+ "n_layers": 16,
41
+ "dim": 2048,
42
+ "n_heads": 32,
43
+ },
44
+ "4b": {
45
+ "n_layers": 16,
46
+ "dim": 4096,
47
+ "n_heads": 32,
48
+ },
49
+ "12b": {
50
+ "n_layers": 40,
51
+ "dim": 5120,
52
+ "n_heads": 32,
53
+ "head_dim": 128,
54
+ },
55
+ }
56
+
57
+ COSMOS_YARN_CONFIG = {
58
+ "original_latent_shape": [3, 40, 64],
59
+ "apply_yarn": True,
60
+ "yarn_beta_fast": 4,
61
+ "yarn_beta_slow": 1,
62
+ "yarn_scale": 2,
63
+ }
64
+
65
+ # Llama3 architecture specifications for different model sizes
66
+ LLAMA3_ARCHITECTURES = {
67
+ "8b": {
68
+ "n_layers": 32,
69
+ "dim": 4096,
70
+ "n_heads": 32,
71
+ "ffn_hidden_size": 14336,
72
+ },
73
+ }
74
+ # Llama3.1 uses YaRN for long context support (context of 128k tokens)
75
+ LLAMA_YARN_CONFIG = {
76
+ "apply_yarn": True,
77
+ "yarn_scale": 8,
78
+ "yarn_beta_fast": 4,
79
+ "yarn_beta_slow": 1,
80
+ }
81
+
82
+ # Mistral architecture specifications for different model sizes
83
+ MISTRAL_ARCHITECTURES = {
84
+ "12b": {
85
+ "n_layers": 40,
86
+ "dim": 5120,
87
+ "n_heads": 32,
88
+ "ffn_hidden_size": 14336,
89
+ "head_dim": 128,
90
+ },
91
+ }
92
+
93
+ PIXTRAL_VISION_ARCHITECTURES = {
94
+ "12b": {"vision_encoder": "pixtral-12b-vit", "mm_projector": "mlp"},
95
+ }
96
+
97
+
98
+ def get_model_arch_specs(model_size: str, model_family: str = "mistral", pretrained: bool = False) -> dict:
99
+ """
100
+ Get the model architecture specifications for the given model size, model family and pretrained status.
101
+
102
+ Args:
103
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", etc.
104
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral"
105
+ pretrained (bool): Whether to load pretrained weights.
106
+
107
+ Returns:
108
+ dict: A dictionary containing the model architecture specifications.
109
+ """
110
+ arch_specs = copy.deepcopy(BASE_CONFIG)
111
+ model_size = model_size.lower()
112
+ if model_family.startswith("cosmos"):
113
+ arch_specs.update(COSMOS_ARCHITECTURES[model_size])
114
+ elif model_family.startswith("llama"):
115
+ arch_specs.update(LLAMA3_ARCHITECTURES[model_size])
116
+ elif model_family in ["mistral", "pixtral"]:
117
+ arch_specs.update(MISTRAL_ARCHITECTURES[model_size])
118
+ if model_family == "pixtral":
119
+ arch_specs.update(PIXTRAL_VISION_ARCHITECTURES[model_size])
120
+ else:
121
+ raise ValueError(f"Model family {model_family} is not supported.")
122
+
123
+ if pretrained:
124
+ if model_family == "cosmos":
125
+ if model_size == "12b":
126
+ arch_specs.update(COSMOS_YARN_CONFIG)
127
+ log.debug(f"Using YaRN for RoPE extension with config: {COSMOS_YARN_CONFIG}")
128
+ else:
129
+ pass
130
+ elif model_family in ["llama", "llama3"]:
131
+ pretrained_specs = {
132
+ "rope_theta": 500000,
133
+ "max_seq_len": 8192,
134
+ "vocab_size": 128256,
135
+ }
136
+ arch_specs.update(pretrained_specs)
137
+ elif model_family == "llama3.1":
138
+ pretrained_specs = {
139
+ "rope_theta": 500000,
140
+ "max_seq_len": 131072,
141
+ "original_seq_len": 8192,
142
+ "vocab_size": 128256,
143
+ **LLAMA_YARN_CONFIG,
144
+ }
145
+ arch_specs.update(pretrained_specs)
146
+ elif model_family == "mistral":
147
+ assert model_size == "12b", "We only support Mistral-Nemo-12B model."
148
+ pretrained_specs = {
149
+ "rope_theta": 1000000,
150
+ "max_seq_len": 128000,
151
+ "vocab_size": 131072,
152
+ }
153
+ arch_specs.update(pretrained_specs)
154
+ elif model_family == "pixtral":
155
+ assert model_size == "12b", "We only support Pixtral 12B model."
156
+ pretrained_specs = {"rope_theta": 1000000000, "max_seq_len": 128000, "vocab_size": 131072}
157
+ arch_specs.update(pretrained_specs)
158
+ else:
159
+ raise ValueError(f"Model family {model_family} doesn't have a pretrained config.")
160
+
161
+ return arch_specs
162
+
163
+
164
+ def create_text_model_config(
165
+ model_ckpt_path: str,
166
+ tokenizer_path: str,
167
+ tensor_model_parallel_size: int = 1,
168
+ model_family: str = "mistral",
169
+ model_size: str = "12b",
170
+ is_instruct_model: bool = True,
171
+ max_seq_len: int = None,
172
+ max_batch_size: int = 1,
173
+ rope_dim: str = "1D",
174
+ add_special_tokens: bool = True,
175
+ pytorch_rope_version: str = None,
176
+ ) -> dict:
177
+ """Create a text model for training or inference.
178
+ Args:
179
+ model_ckpt_path (str): Path to the model checkpoint.
180
+ tokenizer_path (str): Path to the tokenizer folder.
181
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
182
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
183
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "7b", "8b", "72b", etc.
184
+ is_instruct_model (bool): Whether the model is an instruct model.
185
+ inference (bool): Whether to create the model for inference.
186
+ max_seq_len (int): Maximum sequence length.
187
+ max_batch_size (int): Maximum batch size.
188
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
189
+ add_special_tokens (bool): Whether to add special tokens.
190
+ Returns:
191
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
192
+ """
193
+ # Model size specific parameters
194
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
195
+ if max_seq_len is not None:
196
+ # Override the max_seq_len if provided
197
+ model_arch_specs["max_seq_len"] = max_seq_len
198
+ if pytorch_rope_version is not None:
199
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
200
+ model_config = ModelConfig(
201
+ max_batch_size=max_batch_size,
202
+ precision="bfloat16",
203
+ ckpt_path=model_ckpt_path,
204
+ use_qk_normalization=False,
205
+ tensor_model_parallel_size=tensor_model_parallel_size,
206
+ rope_dim=rope_dim,
207
+ **model_arch_specs,
208
+ )
209
+
210
+ tokenizer_config = TokenizerConfig(
211
+ text_tokenizer=TextTokenizerConfig(
212
+ config=L(TextTokenizer)(
213
+ model_family=model_family,
214
+ is_instruct_model=is_instruct_model,
215
+ local_path=tokenizer_path,
216
+ ),
217
+ data_key="text",
218
+ tokenizer_offset=model_config.vocab_size,
219
+ tokenize_here=False,
220
+ vocab_size=model_config.vocab_size,
221
+ ),
222
+ seq_len=model_config.max_seq_len,
223
+ training_type="text_only",
224
+ add_special_tokens=add_special_tokens,
225
+ )
226
+ return model_config, tokenizer_config
227
+
228
+
229
+ def create_vision_language_model_config(
230
+ model_ckpt_path: str,
231
+ tokenizer_ckpt_path: str,
232
+ tensor_model_parallel_size: int = 1,
233
+ model_family: str = "pixtral",
234
+ model_size: str = "12b",
235
+ is_instruct_model: bool = True,
236
+ max_batch_size: int = 1,
237
+ rope_dim: str = "1D",
238
+ add_special_tokens: bool = True,
239
+ max_seq_len: int = None,
240
+ vision_encoder_in_channels: int = 3,
241
+ fuse_qkv: bool = False,
242
+ pytorch_rope_version: str = None,
243
+ ) -> dict:
244
+ """Create a vision-language model for training or inference.
245
+ Args:
246
+ model_ckpt_path (str): Path to the model checkpoint.
247
+ tokenizer_ckpt_path (str): Path to the tokenizer checkpoint.
248
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
249
+ model_family (str): Model family. Choices: "pixtral".
250
+ model_size (str): Model size. Choices: "12b".
251
+ is_instruct_model (bool): Whether the model is an instruct model.
252
+ rope_dim (str): RoPE dimension. Choices: "1D".
253
+ add_special_tokens (bool): Whether to add special tokens.
254
+ max_seq_len (int): Maximum sequence length.
255
+ vision_encoder_in_channels (int): Number of channels in the input image for the vision encoder. Default is 3, you can specify to int larger than 3. E.g. if you have 4 channel images where last channel is binary mask, set this to 4.
256
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
257
+ Returns:
258
+ dict: A dictionary containing the model configuration, which can be used to instantiate the model object.
259
+ """
260
+ # Model size specific parameters
261
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
262
+ if max_seq_len is not None:
263
+ # Override the max_seq_len if provided
264
+ model_arch_specs["max_seq_len"] = max_seq_len
265
+ if pytorch_rope_version is not None:
266
+ model_arch_specs["pytorch_rope_version"] = pytorch_rope_version
267
+
268
+ model_config = ModelConfig(
269
+ max_batch_size=max_batch_size,
270
+ precision="bfloat16",
271
+ ckpt_path=model_ckpt_path,
272
+ use_qk_normalization=False,
273
+ tensor_model_parallel_size=tensor_model_parallel_size,
274
+ rope_dim=rope_dim,
275
+ vision_encoder_in_channels=vision_encoder_in_channels,
276
+ fuse_qkv=fuse_qkv,
277
+ **model_arch_specs,
278
+ )
279
+ # Vision-language tokenizer
280
+ tokenizer_config = TokenizerConfig(
281
+ text_tokenizer=TextTokenizerConfig(
282
+ config=L(ImageTextTokenizer)(
283
+ model_family=model_family,
284
+ is_instruct_model=is_instruct_model,
285
+ image_processor_path=tokenizer_ckpt_path,
286
+ tokenizer_path=tokenizer_ckpt_path,
287
+ ),
288
+ data_key="image_text_interleaved",
289
+ tokenizer_offset=model_config.vocab_size,
290
+ tokenize_here=False,
291
+ vocab_size=model_config.vocab_size,
292
+ ),
293
+ seq_len=model_config.max_seq_len,
294
+ training_type="image_text_interleaved",
295
+ add_special_tokens=add_special_tokens,
296
+ )
297
+ return model_config, tokenizer_config
298
+
299
+
300
+ def create_video2world_model_config(
301
+ model_ckpt_path: str,
302
+ tokenizer_ckpt_path: str,
303
+ tensor_model_parallel_size: int = 1,
304
+ model_family: str = "cosmos",
305
+ model_size: str = "4b",
306
+ pixel_chunk_duration: int = 9,
307
+ num_video_frames: int = 36,
308
+ compression_ratio: List[int] = [8, 16, 16],
309
+ original_seq_len: int = 8192,
310
+ num_condition_latents_t: int = 1,
311
+ num_tokens_to_ignore: int = -1,
312
+ batch_size: int = 2,
313
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
314
+ rope_dim: str = "3D",
315
+ add_special_tokens: bool = True,
316
+ video_height: int = 384,
317
+ video_width: int = 640,
318
+ use_qk_normalization: bool = True,
319
+ insert_cross_attn: bool = False,
320
+ insert_cross_attn_every_k_layers: int = 1,
321
+ context_dim: int = 1024,
322
+ training_type: str = "video_to_video",
323
+ pad_to_multiple_of: Optional[int] = 64,
324
+ vocab_size: int = 64000,
325
+ apply_abs_pos_emb: bool = False,
326
+ ) -> dict:
327
+ """Create a video-to-world model config.
328
+ Args:
329
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
330
+ model_family (str): Model family. Choices: "llama", "llama3", "llama3.1", "mistral".
331
+ model_size (str): Model size. Choices: "1b", "8b", "3b".
332
+ pixel_chunk_duration (int): Number of frames in each chunk.
333
+ num_video_frames (int): Number of video frames.
334
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
335
+ original_seq_len (int): Original sequence length.
336
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
337
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
338
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
339
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
340
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
341
+ training_type (str): Type of training task.
342
+ batch_size (int): Batch size.
343
+ video_tokenizer_config_creator (Callable): Method that takes "pixel_chunk_duration: int" and "version: str" as arguments and returns video tokenizer config
344
+ video_tokenizer_version (str): Version of the video tokenizer.
345
+ num_condition_latents_t (int): Number of temporal latent frames used as conditioning.
346
+ num_tokens_to_ignore (int): Number of tokens to ignore in the loss. If non-negative, this takes precedence over num_condition_latents_t.
347
+ video_height (int): Height of the video frame. Defaults to 384.
348
+ video_width (int): Width of the video frame. Defaults to 640.
349
+ rope_dim (str): RoPE dimension. Choices: "1D", "3D".
350
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
351
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
352
+ vocab_size (int): Vocabulary size.
353
+ apply_abs_pos_emb (bool): Whether to apply absolute positional embeddings.
354
+ Returns:
355
+ tuple: A (model_config, tokenizer_config) pair that can be used to instantiate the model and tokenizer.
356
+ """
357
+ assert (
358
+ pixel_chunk_duration % compression_ratio[0] == 1
359
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
360
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
361
+ latent_height = video_height // compression_ratio[1]
362
+ latent_width = video_width // compression_ratio[2]
363
+ # Compute the video latent shape and sequence length
364
+ assert (
365
+ num_video_frames % pixel_chunk_duration == 0
366
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
367
+ video_latent_shape = [
368
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
369
+ latent_height,
370
+ latent_width,
371
+ ]
372
+ # product of video_latent_shape
373
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
374
+ if add_special_tokens:
375
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
376
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
377
+ # for text-to-video, we need to add a <bov> token to indicate the start of the video
378
+ elif training_type == "text_to_video":
379
+ seq_len = num_token_video_latent + 1
380
+ else:
381
+ seq_len = num_token_video_latent
382
+
383
+ if seq_len % pad_to_multiple_of != 0:
384
+ # Round up to the nearest multiple of pad_to_multiple_of
385
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
386
+
387
+ # Model size specific parameters
388
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=True)
389
+
390
+ # Whether to skip the loss for the first chunk or not; note that the first token is already skipped when computing the loss
391
+ # If num_tokens_to_ignore is specified, use it.
392
+ # Else compute it from num_condition_latents_t
393
+ if num_tokens_to_ignore < 0:
394
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
395
+ if not add_special_tokens and num_condition_latents_t > 0:
396
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
397
+ # from the first token of the next chunk
398
+ num_tokens_to_ignore -= 1
399
+
400
+ model_config = ModelConfig(
401
+ video_height=video_height,
402
+ video_width=video_width,
403
+ max_seq_len=seq_len,
404
+ max_batch_size=batch_size,
405
+ precision="bfloat16",
406
+ ckpt_path=model_ckpt_path,
407
+ use_qk_normalization=use_qk_normalization,
408
+ vocab_size=64000,
409
+ original_seq_len=original_seq_len,
410
+ tensor_model_parallel_size=tensor_model_parallel_size,
411
+ video_latent_shape=video_latent_shape,
412
+ num_video_frames=num_video_frames,
413
+ rope_dim=rope_dim,
414
+ pad_to_multiple_of=pad_to_multiple_of,
415
+ insert_cross_attn=insert_cross_attn,
416
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
417
+ context_dim=context_dim,
418
+ apply_abs_pos_emb=apply_abs_pos_emb,
419
+ **model_arch_specs,
420
+ )
421
+
422
+ video_tokenizer_config = video_tokenizer_config_creator(
423
+ tokenizer_ckpt_path, pixel_chunk_duration, compression_ratio
424
+ )
425
+ tokenizer_config = TokenizerConfig(
426
+ text_tokenizer=None,
427
+ video_tokenizer=VideoTokenizerConfig(
428
+ config=video_tokenizer_config,
429
+ data_key="video",
430
+ tokenizer_offset=0, # No text embeddings in this model, so no offset is needed. Note this only applies when the model is trained from scratch; with a text-pretrained model, the offset would be the text tokenizer's vocab_size.
431
+ tokenize_here=True,
432
+ max_seq_len=num_token_video_latent,
433
+ vocab_size=vocab_size,
434
+ ),
435
+ seq_len=seq_len,
436
+ training_type=training_type,
437
+ add_special_tokens=add_special_tokens,
438
+ pad_to_multiple_of=pad_to_multiple_of,
439
+ )
440
+ return model_config, tokenizer_config
441
+
442
+
443
+ def create_video2world_model(
444
+ tensor_model_parallel_size: int = 1,
445
+ context_parallel_size: int = 1,
446
+ shard_checkpoint: bool = False,
447
+ model_family: str = "cosmos",
448
+ model_size: str = "1b",
449
+ backend: str = "pytorch",
450
+ pixel_chunk_duration: int = 9,
451
+ num_video_frames: int = 36,
452
+ compression_ratio: List[int] = [8, 16, 16],
453
+ original_seq_len: int = 8192,
454
+ apply_yarn: bool = False,
455
+ yarn_beta_fast: Optional[int] = None,
456
+ yarn_beta_slow: Optional[int] = None,
457
+ yarn_scale: Optional[int] = None,
458
+ num_condition_latents_t: int = 1,
459
+ num_tokens_to_ignore: int = -1,
460
+ batch_size: int = 1,
461
+ fsdp_enabled: bool = False,
462
+ act_ckpt_enabled: bool = False,
463
+ video_tokenizer_config_creator: Callable = create_discrete_video_fsq_tokenizer_state_dict_config,
464
+ rope_dim: str = "3D",
465
+ add_special_tokens: bool = False,
466
+ video_height: int = 384,
467
+ video_width: int = 640,
468
+ original_latent_shape: Optional[List[int]] = None,
469
+ use_qk_normalization: bool = True,
470
+ sequence_parallel: bool = False,
471
+ insert_cross_attn: bool = False,
472
+ insert_cross_attn_every_k_layers: int = 1,
473
+ context_dim: int = 1024,
474
+ finetune_layers_with_cross_attn: bool = False,
475
+ finetune_layers_without_cross_attn: bool = False,
476
+ use_action_condition: bool = False,
477
+ action_embedding_mode: Optional[str] = "mlp",
478
+ action_dim: int = 8, # ACTION_DIM,
479
+ action_embedding_dim: int = 1024,
480
+ group_causal_mask_mode: Optional[str] = None,
481
+ training_type: str = "video_to_video",
482
+ pad_to_multiple_of: Optional[int] = 1,
483
+ z_loss_coeff: float = 1e-4,
484
+ temporal_overlap: int = 0,
485
+ embedding_dropout: float = 0.0,
486
+ insert_medusa_head: bool = False,
487
+ ft_medusa_option: str = "fft",
488
+ medusa_num_heads: int = 7,
489
+ medusa_num_layers: int = 1,
490
+ medusa_concat_heads: bool = True,
491
+ fuse_qkv: bool = False,
492
+ zero_init_cross_attn_proj: bool = False,
493
+ concat_action_to_context: bool = False,
494
+ tokenizer_ckpt_path: str = "checkpoints/Cosmos-1.0-Tokenizer-DV8x16x16/ema.jit",
495
+ ) -> dict:
496
+ """Create a video-to-video model for training.
497
+ Args:
498
+ tensor_model_parallel_size (int): Number of tensor model parallel groups.
499
+ context_parallel_size (int): Number of context parallel groups.
500
+ model_family (str): Model family. Choices: "cosmos", "llama", "llama3", "llama3.1", "mistral".
501
+ model_size (str): Model size. Choices: "1b", "3b", "4b", "8b".
502
+ backend (str): Backend for the model. Choices: "pytorch", "transformer_engine".
503
+ pixel_chunk_duration (int): Number of frames in each chunk.
504
+ num_video_frames (int): Number of video frames.
505
+ compression_ratio (List[int]): Compression ratio for the video frames. Choices: [8, 16, 16] or [4, 8, 8].
506
+ original_seq_len (int): Original sequence length.
507
+ apply_yarn (bool): Whether to apply YaRN for long context scaling.
508
+ yarn_beta_fast (Optional[int]): Fast beta for YaRN.
509
+ yarn_beta_slow (Optional[int]): Slow beta for YaRN.
510
+ yarn_scale (Optional[int]): Scale factor for ctx extension.
511
+ fsdp_enabled (bool): Whether Fully Sharded Data Parallel (FSDP) is enabled.
512
+ act_ckpt_enabled (bool): Whether activation checkpointing is enabled.
513
+ use_qk_normalization (bool): Whether to use Query-Key normalization.
514
+ training_type (str): Type of training task.
515
+ batch_size (int): Batch size.
516
+ video_tokenizer_config_creator (Callable): Callable that builds the video tokenizer config; it takes the tokenizer ckpt_path, pixel_chunk_duration, and compression_ratio as arguments
517
+ video_tokenizer_version (str): Version of the video tokenizer.
518
+ num_condition_latents_t (int): Number of temporal latent frames used as conditioning.
519
+ num_tokens_to_ignore (int): Number of tokens to ignore in the loss. If non-negative, this takes precedence over num_condition_latents_t.
520
+ video_height (int): Height of the video frame. Defaults to 384.
521
+ video_width (int): Width of the video frame. Defaults to 640.
522
+ rope_dim (str): RoPE dimension. Choices: "1D", "2D", "3D".
523
+ add_special_tokens (bool): Whether to add special tokens, use False for 2D/3D RoPE.
524
+ original_latent_shape (list): Original latent shape before RoPE scaling.
525
+ sequence_parallel (bool): Whether to enable sequence parallelism.
526
+ insert_cross_attn (bool): Whether to insert the cross-attention layers after each multi-head self-attention (MSA) layer.
527
+ insert_cross_attn_every_k_layers (int): Insert cross-attention layers every k TransformerLayers.
528
+ context_dim (Optional[int]): The dimensionality of cross-attention embedding, e.g., T5 embed feature dim.
529
+ finetune_layers_with_cross_attn (bool): Whether to finetune Transformer layers w/ CA (cross-attn).
530
+ finetune_layers_without_cross_attn (bool): Whether to finetune Transformer layers w/o CA (cross-attn).
531
+ use_action_condition (bool): Whether to use action condition.
532
+ action_embedding_mode (Optional[str]): The mode of the robot action embedding. Choices: "matrix", "mlp".
533
+ action_dim (int): Dimension of the raw robot action tensor (e.g., 7 for DROID, [Δx, Δy, Δz, rx, ry, rz, gripper_open]).
534
+ action_embedding_dim (int): Dimension of the action embedding.
535
+ group_causal_mask_mode (Optional[str]): The mode of the group causal mask. Choices: "causal", "group_diagonal".
536
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 1.
537
+ z_loss_coeff (float): Coefficient for the z loss.
538
+ temporal_overlap (int): Temporal overlap in the latent space.
539
+ embedding_dropout (float): Dropout rate for the embeddings.
540
+ insert_medusa_head (bool): Whether to insert the Medusa head.
541
+ ft_medusa_option (str): Options on which layers to finetune, choices like:
542
+ "fft": fully fine-tune both medusa heads and all LLM backbone;
543
+ "head": fine-tune medusa heads;
544
+ "head_out": fine-tune medusa heads, and the output layer;
545
+ "head_out_last_k_layer": fine-tune medusa heads, the output layer, and the last k layer(s) of the LLM backbone.
546
+ medusa_num_heads (int): Number of heads in the Medusa head.
547
+ medusa_num_layers (int): Number of layers in the Medusa head.
548
+ medusa_concat_heads (bool): Whether to concatenate multiple medusa heads into fused matrix, only applicable when medusa_num_layers = 1.
549
+ fuse_qkv (bool): Whether to fuse the QKV linear layers.
550
+ zero_init_cross_attn_proj (bool): Whether to zero-initialize the cross-attention projection weights (default False).
551
+ concat_action_to_context (bool): Whether to concatenate the action embedding to the context (default False).
552
+ Returns:
553
+ dict: A lazy config (LazyDict) for AutoRegressiveTrainingModel.build that instantiates the training model.
554
+ """
555
+ assert (
556
+ pixel_chunk_duration % compression_ratio[0] == 1
557
+ ), f"pixel_chunk_duration({pixel_chunk_duration}) should be k*n + 1 (k={compression_ratio[0]})"
558
+ latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1
559
+ latent_height = video_height // compression_ratio[1]
560
+ latent_width = video_width // compression_ratio[2]
561
+ # Compute the video latent shape and sequence length
562
+ if temporal_overlap == 0:
563
+ assert (
564
+ num_video_frames % pixel_chunk_duration == 0
565
+ ), f"num_video_frames {num_video_frames} should be divisible by pixel_chunk_duration {pixel_chunk_duration}"
566
+ video_latent_shape = [
567
+ num_video_frames // pixel_chunk_duration * latent_chunk_duration,
568
+ latent_height,
569
+ latent_width,
570
+ ]
571
+
572
+ else:
573
+ # Calculate temporal overlap in the latent space
574
+ temporal_overlap_latent = temporal_overlap // compression_ratio[0]
575
+
576
+ # Calculate the effective number of latent chunks for the video
577
+ latent_chunks = (num_video_frames - temporal_overlap) // (pixel_chunk_duration - temporal_overlap)
578
+
579
+ # Compute the total duration of the latent chunks, accounting for overlap
580
+ effective_latent_duration = (
581
+ latent_chunk_duration - temporal_overlap_latent
582
+ ) * latent_chunks + temporal_overlap_latent
583
+
584
+ # Define the shape of the video in the latent space
585
+ video_latent_shape = [
586
+ effective_latent_duration, # Temporal dimension
587
+ latent_height, # Height in the latent space
588
+ latent_width, # Width in the latent space
589
+ ]
590
+
591
+ # product of video_latent_shape
592
+ num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
593
+ if add_special_tokens:
594
+ seq_len = num_token_video_latent + 3 # Sequence length per batch, max_seq_len + 3
595
+ seq_len = (seq_len + 63) // 64 * 64 # Round up to multiple of 64
596
+ # for text-to-video, we need to add a <bov> token to indicate the start of the video
597
+ elif training_type == "text_to_video":
598
+ seq_len = num_token_video_latent + 1
599
+ else:
600
+ seq_len = num_token_video_latent
601
+
602
+ if seq_len % pad_to_multiple_of != 0:
603
+ # Round up to the nearest multiple of pad_to_multiple_of
604
+ seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
605
+
606
+ # Model size specific parameters
607
+ model_arch_specs = get_model_arch_specs(model_family=model_family, model_size=model_size, pretrained=False)
608
+
609
+ inference = False # False for training, True for inference
610
+ # set_parallel_mode = True
611
+ set_parallel_mode = tensor_model_parallel_size > 1
612
+ attention_tp = True
613
+
614
+ if context_parallel_size > 1:
615
+ assert backend == "transformer_engine", "Context parallelism is only supported in transformer engine."
616
+
617
+ if tensor_model_parallel_size > 1:
618
+ assert set_parallel_mode, "Tensor model parallelism is only supported in parallel mode."
619
+
620
+ # Whether to skip the loss for the first chunk or not; note that the first token is already skipped when computing the loss
621
+ # If num_tokens_to_ignore is specified, use it.
622
+ # Else compute it from num_condition_latents_t
623
+ if num_tokens_to_ignore < 0:
624
+ num_tokens_to_ignore = latent_height * latent_width * num_condition_latents_t
625
+ if not add_special_tokens and num_condition_latents_t > 0:
626
+ # If there are no special tokens (bov), do a -1 so that you can compute the loss
627
+ # from the first token of the next chunk
628
+ num_tokens_to_ignore -= 1
629
+
630
+ model_config = TrainingModelConfig(
631
+ video_height=video_height,
632
+ video_width=video_width,
633
+ max_seq_len=seq_len,
634
+ max_batch_size=batch_size,
635
+ inference=inference,
636
+ backend=backend,
637
+ precision="bfloat16",
638
+ ema=EMAConfig(enabled=False),
639
+ act_ckpt_enabled=act_ckpt_enabled,
640
+ fsdp_enabled=fsdp_enabled,
641
+ cache_dir=None,
642
+ ckpt_path="checkpoints/Cosmos-Predict1-4B/model.pt",
643
+ use_qk_normalization=use_qk_normalization,
644
+ vocab_size=64000,
645
+ ignore_first_num_tokens=num_tokens_to_ignore,
646
+ apply_yarn=apply_yarn,
647
+ yarn_beta_fast=yarn_beta_fast,
648
+ yarn_beta_slow=yarn_beta_slow,
649
+ original_seq_len=original_seq_len,
650
+ yarn_scale=yarn_scale,
651
+ context_parallel_size=context_parallel_size,
652
+ tensor_model_parallel_size=tensor_model_parallel_size,
653
+ set_parallel_mode=set_parallel_mode,
654
+ attention_tp=attention_tp,
655
+ video_latent_shape=video_latent_shape,
656
+ num_video_frames=num_video_frames,
657
+ rope_dim=rope_dim,
658
+ original_latent_shape=original_latent_shape,
659
+ pad_to_multiple_of=pad_to_multiple_of,
660
+ sequence_parallel=sequence_parallel,
661
+ insert_cross_attn=insert_cross_attn,
662
+ insert_cross_attn_every_k_layers=insert_cross_attn_every_k_layers,
663
+ context_dim=context_dim,
664
+ finetune_layers_with_cross_attn=finetune_layers_with_cross_attn,
665
+ finetune_layers_without_cross_attn=finetune_layers_without_cross_attn,
666
+ use_action_condition=use_action_condition,
667
+ action_embedding_mode=action_embedding_mode,
668
+ action_dim=action_dim,
669
+ action_embedding_dim=action_embedding_dim,
670
+ group_causal_mask_mode=group_causal_mask_mode,
671
+ z_loss_coeff=z_loss_coeff,
672
+ embedding_dropout=embedding_dropout,
673
+ insert_medusa_head=insert_medusa_head,
674
+ ft_medusa_option=ft_medusa_option,
675
+ medusa_num_heads=medusa_num_heads,
676
+ medusa_num_layers=medusa_num_layers,
677
+ medusa_concat_heads=medusa_concat_heads,
678
+ fuse_qkv=fuse_qkv,
679
+ zero_init_cross_attn_proj=zero_init_cross_attn_proj,
680
+ concat_action_to_context=concat_action_to_context,
681
+ **model_arch_specs,
682
+ )
683
+
684
+ tokenizer_config = TokenizerConfig(
685
+ text_tokenizer=None,
686
+ video_tokenizer=VideoTokenizerConfig(
687
+ config=video_tokenizer_config_creator(
688
+ ckpt_path=tokenizer_ckpt_path, pixel_chunk_duration=pixel_chunk_duration
689
+ ),
690
+ data_key="video",
691
+ tokenizer_offset=0,
692
+ vocab_size=64000,
693
+ tokenize_here=True,
694
+ max_seq_len=num_token_video_latent,
695
+ temporal_overlap=temporal_overlap,
696
+ ),
697
+ seq_len="${model.model_config.max_seq_len}",
698
+ training_type=training_type,
699
+ add_special_tokens=add_special_tokens,
700
+ pad_to_multiple_of=pad_to_multiple_of,
701
+ )
702
+
703
+ model_parallel = ModelParallelConfig(
704
+ bf16=True,
705
+ params_dtype=getattr(torch, "bfloat16"),
706
+ )
707
+ model_parallel.tensor_model_parallel_size = "${model.model_config.tensor_model_parallel_size}"
708
+ model_parallel.context_parallel_size = "${model.model_config.context_parallel_size}"
709
+ model_parallel.sequence_parallel = "${model.model_config.sequence_parallel}"
710
+ return L(AutoRegressiveTrainingModel.build)(
711
+ seed=0,
712
+ train_from_scratch=True,
713
+ model_config=model_config,
714
+ fsdp_checkpointer=None,
715
+ tokenizer_config=tokenizer_config,
716
+ model_parallel=model_parallel,
717
+ shard_checkpoint=shard_checkpoint,
718
+ )
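The latent-shape and sequence-length arithmetic above is easy to sanity-check by hand. A minimal sketch (not part of this diff), plugging in the 384x640, 33-frame, chunk-33 settings used by the tealrobot experiments later in this commit, with the default [8, 16, 16] compression and no special tokens:

```python
# Sketch of the latent-shape / seq_len math in create_video2world_model_config
# for the tealrobot-small settings (384x640, 33 frames, chunk 33, compression [8, 16, 16]).
pixel_chunk_duration, num_video_frames = 33, 33
video_height, video_width = 384, 640
compression_ratio = [8, 16, 16]
pad_to_multiple_of = 64

latent_chunk_duration = (pixel_chunk_duration - 1) // compression_ratio[0] + 1  # 5
latent_height = video_height // compression_ratio[1]                            # 24
latent_width = video_width // compression_ratio[2]                              # 40
video_latent_shape = [
    num_video_frames // pixel_chunk_duration * latent_chunk_duration,           # 5
    latent_height,
    latent_width,
]
num_token_video_latent = video_latent_shape[0] * video_latent_shape[1] * video_latent_shape[2]
seq_len = num_token_video_latent  # no special tokens, not text_to_video
if seq_len % pad_to_multiple_of != 0:
    seq_len = ((seq_len + pad_to_multiple_of - 1) // pad_to_multiple_of) * pad_to_multiple_of
print(video_latent_shape, num_token_video_latent, seq_len)  # [5, 24, 40] 4800 4800
```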
cosmos_predict1/autoregressive/configs/base/model_parallel.py ADDED
@@ -0,0 +1,33 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from megatron.core import ModelParallelConfig
18
+
19
+ from cosmos_predict1.utils.lazy_config import LazyDict
20
+
21
+
22
+ def create_model_parallel_config():
23
+ model_parallel = ModelParallelConfig(bf16=True, params_dtype=getattr(torch, "bfloat16"))
24
+ model_parallel.tensor_model_parallel_size = "${model.model_parallel.tensor_model_parallel_size}"
25
+ model_parallel.context_parallel_size = "${model.model_parallel.context_parallel_size}"
26
+ model_parallel.sequence_parallel = "${model.model_parallel.sequence_parallel}"
27
+ MODEL_PARALLELS = LazyDict(
28
+ dict(
29
+ model_parallel_bf16=model_parallel,
30
+ ),
31
+ flags={"allow_objects": True},
32
+ )
33
+ return MODEL_PARALLELS["model_parallel_bf16"]
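The "${model...}" strings above are OmegaConf interpolations that Hydra resolves against the composed config tree; they are not literal values. A minimal sketch (assumed, not from the repository) of how such an interpolation resolves:

```python
# Minimal sketch (assumed): OmegaConf interpolations resolve lazily against the config tree.
from omegaconf import OmegaConf

cfg = OmegaConf.create(
    {
        "model": {"model_config": {"tensor_model_parallel_size": 4}},
        "model_parallel": {"tensor_model_parallel_size": "${model.model_config.tensor_model_parallel_size}"},
    }
)
assert cfg.model_parallel.tensor_model_parallel_size == 4  # resolved on access
```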
cosmos_predict1/autoregressive/configs/base/optim.py ADDED
@@ -0,0 +1,86 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+
18
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
19
+
20
+
21
+ class LambdaLinearWarmupScheduler:
22
+ """
23
+ A learning rate scheduler that implements linear warm-up and cool-down.
24
+
25
+ This scheduler provides three phases:
26
+ 1. Warm-up: Learning rate linearly increases from 0 to 1.
27
+ 2. Constant: Learning rate remains at 1.
28
+ 3. Cool-down: Learning rate linearly decreases from 1 to 0.
29
+
30
+ Args:
31
+ warmup_steps (int): Number of steps for the warm-up phase.
32
+ warmup_offset (int): Starts warmup from this offset.
33
+ max_iter (int, optional): Total number of iterations. Required if cooldown_steps is provided.
34
+ cooldown_steps (int, optional): Number of steps for the cool-down phase.
35
+
36
+ Raises:
37
+ ValueError: If cooldown_steps is provided without max_iter, or if an invalid step is given.
38
+ """
39
+
40
+ def __init__(self, warmup_steps: int, warmup_offset: int = 0, max_iter: int = None, cooldown_steps: int = None):
41
+ self.warmup_steps = warmup_steps
42
+ self.warmup_offset = warmup_offset
43
+ self.max_iter = max_iter
44
+ self.cooldown_steps = cooldown_steps
45
+
46
+ if cooldown_steps is not None:
47
+ if max_iter is None:
48
+ raise ValueError("max_iter must be specified when cooldown_steps is provided")
49
+ self.cooldown_start = max_iter - cooldown_steps
50
+ else:
51
+ self.cooldown_start = None
52
+
53
+ def __call__(self, step):
54
+ # Warm-up phase
55
+ if step < self.warmup_offset:
56
+ return 0
57
+
58
+ if step < self.warmup_steps + self.warmup_offset:
59
+ return float(step - self.warmup_offset) / float(max(1, self.warmup_steps))
60
+
61
+ # Constant phase (no cool-down)
62
+ elif self.cooldown_steps is None:
63
+ return 1.0
64
+
65
+ # Constant phase (before cool-down starts)
66
+ elif step < self.cooldown_start:
67
+ return 1.0
68
+
69
+ # Cool-down phase
70
+ elif self.cooldown_start <= step < self.max_iter:
71
+ cooldown_progress = (step - self.cooldown_start) / self.cooldown_steps
72
+ return 1.0 - cooldown_progress
73
+
74
+ # After max_iter
75
+ elif step >= self.max_iter:
76
+ return 0.0
77
+
78
+ # Unexpected case
79
+ else:
80
+ raise ValueError(f"Invalid step {step}")
81
+
82
+
83
+ LambdaLinearLR = L(torch.optim.lr_scheduler.LambdaLR)(
84
+ optimizer=None,
85
+ lr_lambda=L(LambdaLinearWarmupScheduler)(warmup_steps=5000),
86
+ )
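For reference, the same warm-up/cool-down schedule can be attached to an optimizer directly, without the LazyCall wrapper. A minimal usage sketch (assumed, not part of the repository):

```python
# Minimal usage sketch (assumed): LambdaLinearWarmupScheduler driving a plain LambdaLR.
import torch

params = [torch.nn.Parameter(torch.zeros(1))]
optimizer = torch.optim.AdamW(params, lr=1e-3)
scheduler = torch.optim.lr_scheduler.LambdaLR(
    optimizer,
    lr_lambda=LambdaLinearWarmupScheduler(warmup_steps=10, max_iter=100, cooldown_steps=20),
)
for _ in range(100):
    optimizer.step()
    scheduler.step()
# LR ramps from 0 to 1e-3 over the first 10 steps, stays flat until step 80,
# then decays linearly back to 0 over the last 20 steps.
```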
cosmos_predict1/autoregressive/configs/base/tokenizer.py ADDED
@@ -0,0 +1,139 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Optional
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.tokenizer.discrete_video import DiscreteVideoFSQStateDictTokenizer
21
+ from cosmos_predict1.autoregressive.tokenizer.networks import CausalDiscreteVideoTokenizer
22
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
23
+ from cosmos_predict1.utils.lazy_config import LazyDict
24
+
25
+
26
+ def create_discrete_video_fsq_tokenizer_state_dict_config(
27
+ ckpt_path, pixel_chunk_duration=33, compression_ratio=[8, 16, 16]
28
+ ) -> LazyDict:
29
+ CausalDiscreteFactorizedVideoTokenizerConfig: LazyDict = L(CausalDiscreteVideoTokenizer)(
30
+ # The new causal discrete tokenizer, that is at least 2x more efficient in memory and runtime.
31
+ # - It relies on fully 3D discrete wavelet transform
32
+ # - Uses a layer norm instead of a group norm
33
+ # - Factorizes full convolutions into spatial and temporal convolutions
34
+ # - Factorizes full attention into spatial and temporal attention
35
+ # - Strictly causal, with flexible temporal length at inference.
36
+ attn_resolutions=[32],
37
+ channels=128,
38
+ channels_mult=[2, 4, 4],
39
+ dropout=0.0,
40
+ in_channels=3,
41
+ num_res_blocks=2,
42
+ out_channels=3,
43
+ resolution=1024,
44
+ patch_size=4,
45
+ patch_method="haar",
46
+ z_channels=16,
47
+ z_factor=1,
48
+ num_groups=1,
49
+ legacy_mode=False,
50
+ spatial_compression=16,
51
+ temporal_compression=8,
52
+ embedding_dim=6,
53
+ levels=[8, 8, 8, 5, 5, 5],
54
+ name="CausalDiscreteFactorizedVideoTokenizer",
55
+ )
56
+
57
+ return L(DiscreteVideoFSQStateDictTokenizer)(
58
+ enc_fp=ckpt_path.replace("ema.jit", "encoder.jit"),
59
+ dec_fp=ckpt_path.replace("ema.jit", "decoder.jit"),
60
+ tokenizer_module=CausalDiscreteFactorizedVideoTokenizerConfig,
61
+ name="discrete_video_fsq",
62
+ latent_ch=6,
63
+ is_bf16=True,
64
+ pixel_chunk_duration=pixel_chunk_duration,
65
+ latent_chunk_duration=1 + (pixel_chunk_duration - 1) // compression_ratio[0],
66
+ max_enc_batch_size=8,
67
+ max_dec_batch_size=4,
68
+ levels=[8, 8, 8, 5, 5, 5],
69
+ compression_ratio=compression_ratio,
70
+ )
71
+
72
+
73
+ @attrs.define(slots=False)
74
+ class TextTokenizerConfig:
75
+ """
76
+ Text tokenizer config
77
+
78
+ Args:
79
+ config: Config file to define the text tokenizer class.
80
+ data_key (str): The input key from data_dict that will be passed to the text tokenizer.
81
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
82
+ tokenizer_offset (int): Offset that is added to the tokens.
83
+ vocab_size (int): Vocabulary size of the tokenizer.
84
+ """
85
+
86
+ config: LazyDict
87
+ data_key: str = ""
88
+ tokenize_here: bool = False
89
+ tokenizer_offset: int = 0
90
+ vocab_size: int = 0
91
+
92
+
93
+ @attrs.define(slots=False)
94
+ class VideoTokenizerConfig:
95
+ """
96
+ Video tokenizer config
97
+
98
+ Args:
99
+ config: Config file to define the video tokenizer class.
100
+ data_key (str): The input key from data_dict that will be passed to the video tokenizer.
101
+ tokenize_here (bool): Whether to use the tokenizer to perform online tokenization.
102
+ tokenizer_offset (int): Offset that is added to the tokens. In case of joint text-video tokenizers, we
103
+ add an offset to make sure that video tokens and text tokens don't overlap.
104
+ vocab_size (int): Vocabulary size of the tokenizer.
105
+ max_seq_len (int): Maximum token length for an input video.
106
+ temporal_overlap (int): Overlap between consecutive video chunks.
107
+ """
108
+
109
+ config: LazyDict
110
+ data_key: str = ""
111
+ tokenize_here: bool = True
112
+ tokenizer_offset: int = 0
113
+ vocab_size: int = 0
114
+ max_seq_len: int = -1
115
+ temporal_overlap: int = 0
116
+
117
+
118
+ @attrs.define(slots=False)
119
+ class TokenizerConfig:
120
+ """
121
+ Joint tokenizer config
122
+
123
+ Args:
124
+ text_tokenizer (TextTokenizerConfig): Text tokenizer config file
125
+ class_tokenizer (ClassTokenizerConfig): Class tokenizer config file
126
+ video_tokenizer (VideoTokenizerConfig): Video tokenizer config file
127
+ image_tokenizer (ImageTokenizerConfig): Image tokenizer config file
128
+ seq_len (int): Final token sequence length
129
+ training_type (str): Type of training we use. Supports ["text_only", "text_to_video", "class_to_image", "image_text_interleaved"]
130
+ add_special_tokens (bool): Whether to add special tokens to the output tokens
131
+ pad_to_multiple_of (int): Pad the token sequence length to the nearest multiple of this number. Defaults to 64.
132
+ """
133
+
134
+ text_tokenizer: Optional[TextTokenizerConfig] = None
135
+ video_tokenizer: Optional[VideoTokenizerConfig] = None
136
+ seq_len: int = 4096
137
+ training_type: Optional[str] = None
138
+ add_special_tokens: bool = True
139
+ pad_to_multiple_of: Optional[int] = 64
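A minimal sketch (assumed, not part of the repository) of how these dataclasses compose into a video-only tokenizer config, mirroring what create_video2world_model_config builds; the checkpoint path is the one used by the experiment configs in this commit:

```python
# Minimal sketch (assumed): video-only TokenizerConfig, mirroring create_video2world_model_config.
video_tokenizer_cfg = create_discrete_video_fsq_tokenizer_state_dict_config(
    ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
    pixel_chunk_duration=33,
    compression_ratio=[8, 16, 16],
)
tokenizer_config = TokenizerConfig(
    text_tokenizer=None,
    video_tokenizer=VideoTokenizerConfig(
        config=video_tokenizer_cfg,
        data_key="video",
        tokenize_here=True,
        vocab_size=64000,
        max_seq_len=4800,  # number of video latent tokens (see the seq_len sketch earlier)
    ),
    seq_len=4800,
    training_type="video_to_video",
    add_special_tokens=False,
)
```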
cosmos_predict1/autoregressive/configs/config.py ADDED
@@ -0,0 +1,111 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Default config for cosmos_ar project."""
17
+
18
+ import os
19
+ from typing import Any, List
20
+
21
+ import attrs
22
+
23
+ from cosmos_predict1.autoregressive.configs.registry import register_configs
24
+ from cosmos_predict1.autoregressive.trainer import Trainer
25
+ from cosmos_predict1.utils import config, log
26
+ from cosmos_predict1.utils.config_helper import import_all_modules_from_package
27
+
28
+
29
+ @attrs.define(slots=False)
30
+ class Config(config.Config):
31
+ defaults: List[Any] = attrs.field(
32
+ factory=lambda: [
33
+ "_self_",
34
+ {"model": None},
35
+ {"data_train": "mock_video"},
36
+ {"data_val": None},
37
+ {"optimizer": "fused_adamw"},
38
+ {"scheduler": "warmup_cosine_lr"},
39
+ {"checkpoint": "local"},
40
+ {"callbacks": "basic"},
41
+ {"global_config": None},
42
+ {"experiment": None},
43
+ ]
44
+ )
45
+
46
+ def validate(self) -> None:
47
+ """Validate that the config has all required fields."""
48
+ assert self.job.project != "", "job.project is not set"
49
+ assert self.job.group != "", "job.group is not set"
50
+ assert self.job.name != "", "job.name is not set"
51
+ log.info("Validating config for cosmos_autoregressive job")
52
+ # FSDP config check
53
+ if self.model.model_config.fsdp_enabled:
54
+ assert self.trainer.distributed_parallelism == "fsdp"
55
+ else:
56
+ assert self.trainer.distributed_parallelism == "ddp"
57
+
58
+ # Transformer Engine config check
59
+ if self.model.model_config.backend == "transformer_engine":
60
+ assert (
61
+ "NVTE_FLASH_ATTN" in os.environ and os.environ["NVTE_FLASH_ATTN"] == "1"
62
+ ) # Enable Flash attention for transformer engine
63
+
64
+ # TP, CP config check
65
+ if self.model_parallel is not None:
66
+ if self.model_parallel.context_parallel_size > 1:
67
+ assert (
68
+ self.model.model_config.backend == "transformer_engine"
69
+ ), "Context parallelism is only supported in transformer engine."
70
+
71
+ if self.model_parallel.tensor_model_parallel_size > 1:
72
+ assert (
73
+ self.model.model_config.set_parallel_mode
74
+ ), "Tensor model parallelism is only supported in parallel mode."
75
+
76
+ if self.model_parallel.sequence_parallel:
77
+ assert (
78
+ self.model_parallel.tensor_model_parallel_size > 1
79
+ ), "Sequence parallelism is only supported in tensor model parallelism."
80
+ assert (
81
+ self.model.model_config.backend == "transformer_engine"
82
+ ), "Sequence parallelism is only supported in transformer engine."
83
+
84
+
85
+ def make_config():
86
+ c = Config(
87
+ model=None,
88
+ optimizer=None,
89
+ scheduler=None,
90
+ dataloader_train=None,
91
+ dataloader_val=None,
92
+ checkpoint=None,
93
+ )
94
+
95
+ c.job.project = "cosmos_autoregressive"
96
+ c.job.group = "debug"
97
+ c.job.name = "default_${now:%Y-%m-%d}_${now:%H-%M-%S}"
98
+
99
+ c.trainer.type = Trainer
100
+ c.trainer.run_validation = True
101
+
102
+ c.trainer.seed = 0
103
+ c.trainer.max_iter = 10
104
+ c.trainer.logging_iter = 1
105
+
106
+ c.trainer.callbacks = None
107
+ register_configs()
108
+ # experiment configs are defined in the experiment folder
109
+ # call import_all_modules_from_package to register them
110
+ import_all_modules_from_package("cosmos_predict1.autoregressive.configs.experiment")
111
+ return c
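A minimal sketch (assumed) of what make_config() yields before Hydra applies an experiment override; the printed values follow directly from the defaults set above:

```python
# Minimal sketch (assumed): inspecting the default config before any experiment override.
cfg = make_config()
print(cfg.job.project)       # "cosmos_autoregressive"
print(cfg.trainer.max_iter)  # 10 -- debug default; experiment configs override this
print(cfg.defaults[:3])      # ["_self_", {"model": None}, {"data_train": "mock_video"}]
```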
cosmos_predict1/autoregressive/configs/experiment/video2video/__init__.py ADDED
File without changes
cosmos_predict1/autoregressive/configs/experiment/video2video/basic.py ADDED
@@ -0,0 +1,163 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This file contains a basic configuration for video2video experiments.
18
+ """
19
+
20
+ from hydra.core.config_store import ConfigStore
21
+
22
+ from cosmos_predict1.autoregressive.configs.base.model_config import create_video2world_model
23
+ from cosmos_predict1.autoregressive.configs.base.model_parallel import create_model_parallel_config
24
+ from cosmos_predict1.utils import log
25
+ from cosmos_predict1.utils.lazy_config import LazyDict
26
+
27
+ cs = ConfigStore.instance()
28
+
29
+
30
+ """
31
+ Finetune 4B model with TP=1, pytorch backend, low resolution tealrobot data, frames 33, chunk 33.
32
+ Usage:
33
+ torchrun --nproc_per_node=1 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobotsmall_tp1
34
+ """
35
+ base_4b_example_tealrobotsmall_tp1: LazyDict = LazyDict(
36
+ dict(
37
+ defaults=[
38
+ {"override /data_train": "tealrobot_video_small"},
39
+ {
40
+ "override /callbacks": [
41
+ "basic",
42
+ "video_teacher_forcing",
43
+ ]
44
+ },
45
+ {"override /checkpoint": "local"},
46
+ {"override /optimizer": "fused_adamw"},
47
+ {"override /scheduler": "warmup_cosine_lr"},
48
+ "_self_",
49
+ ],
50
+ job=dict(
51
+ project="posttraining",
52
+ group="autoregressive_base",
53
+ name="base_4b_example_tealrobotsmall_tp1",
54
+ ),
55
+ model=create_video2world_model(
56
+ model_size="4b",
57
+ model_family="cosmos",
58
+ backend="pytorch",
59
+ tensor_model_parallel_size=1,
60
+ batch_size=1,
61
+ pixel_chunk_duration=33,
62
+ num_video_frames=33,
63
+ video_height=384,
64
+ video_width=640,
65
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
66
+ add_special_tokens=False,
67
+ ),
68
+ trainer=dict(
69
+ max_iter=50000,
70
+ grad_accum_iter=1,
71
+ grad_scaler_args=dict(enabled=False),
72
+ run_validation=False, # No need for validation as epoch <= 1
73
+ distributed_parallelism="ddp",
74
+ callbacks=dict(
75
+ vid_sampling_tf=dict(
76
+ every_n=500,
77
+ ),
78
+ ),
79
+ ),
80
+ checkpoint=dict(
81
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
82
+ load_training_state=False,
83
+ strict_resume=True,
84
+ save_iter=1000,
85
+ ),
86
+ model_parallel=create_model_parallel_config(),
87
+ ),
88
+ )
89
+
90
+
91
+ """
92
+ Finetune 4B model with TP=4, pytorch backend, high resolution tealrobot data, frame 33, chunk 33.
93
+ Usage:
94
+ torchrun --nproc_per_node=4 -m cosmos_predict1.autoregressive.train --config=cosmos_predict1/autoregressive/configs/config.py -- experiment=base_4b_example_tealrobot_tp4
95
+ """
96
+ base_4b_example_tealrobot_tp4: LazyDict = LazyDict(
97
+ dict(
98
+ defaults=[
99
+ {"override /data_train": "tealrobot_video"},
100
+ {
101
+ "override /callbacks": [
102
+ "basic",
103
+ "video_teacher_forcing",
104
+ ]
105
+ },
106
+ {"override /checkpoint": "local"},
107
+ {"override /optimizer": "fused_adamw"},
108
+ {"override /scheduler": "warmup_cosine_lr"},
109
+ "_self_",
110
+ ],
111
+ job=dict(
112
+ project="posttraining",
113
+ group="autoregressive_base",
114
+ name="base_4b_example_tealrobot_tp4",
115
+ ),
116
+ model=create_video2world_model(
117
+ model_size="4b",
118
+ model_family="cosmos",
119
+ backend="pytorch",
120
+ tensor_model_parallel_size=4,
121
+ batch_size=1,
122
+ pixel_chunk_duration=33,
123
+ num_video_frames=33,
124
+ video_height=640,
125
+ video_width=848,
126
+ tokenizer_ckpt_path="checkpoints/Cosmos-Tokenize1-DV8x16x16-720p/ema.jit",
127
+ add_special_tokens=False,
128
+ ),
129
+ trainer=dict(
130
+ max_iter=50000,
131
+ grad_accum_iter=1,
132
+ grad_scaler_args=dict(enabled=False),
133
+ run_validation=False, # No need for validation as epoch <= 1
134
+ distributed_parallelism="ddp",
135
+ callbacks=dict(
136
+ vid_sampling_tf=dict(
137
+ every_n=500,
138
+ ),
139
+ ),
140
+ ),
141
+ checkpoint=dict(
142
+ load_path="checkpoints/Cosmos-Predict1-4B/model.pt",
143
+ load_training_state=False,
144
+ strict_resume=False,
145
+ save_iter=1000,
146
+ ),
147
+ model_parallel=create_model_parallel_config(),
148
+ ),
149
+ )
150
+
151
+
152
+ def register_experiments(cs):
153
+ # Register the experiments
154
+ for _item in [
155
+ base_4b_example_tealrobotsmall_tp1,
156
+ base_4b_example_tealrobot_tp4,
157
+ ]:
158
+ cs.store(
159
+ group="experiment",
160
+ package="_global_",
161
+ name=_item["job"]["name"],
162
+ node=_item,
163
+ )
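New experiments can follow the same pattern. A minimal sketch (assumed, not part of the repository); the experiment name and the overridden fields are hypothetical:

```python
# Sketch (assumed): deriving a new experiment from an existing one;
# "my_experiment_tp1" and the overridden values are hypothetical.
import copy

my_experiment_tp1 = copy.deepcopy(base_4b_example_tealrobotsmall_tp1)
my_experiment_tp1["job"]["name"] = "my_experiment_tp1"
my_experiment_tp1["trainer"]["max_iter"] = 10000
my_experiment_tp1["checkpoint"]["save_iter"] = 500

cs.store(group="experiment", package="_global_", name="my_experiment_tp1", node=my_experiment_tp1)
```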
cosmos_predict1/autoregressive/configs/inference/inference_config.py ADDED
@@ -0,0 +1,102 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, List, Optional, Union
17
+
18
+ import attrs
19
+
20
+ from cosmos_predict1.autoregressive.configs.base.model import ModelConfig, TokenizerConfig
21
+
22
+
23
+ @attrs.define(slots=False)
24
+ class DataShapeConfig:
25
+ latent_shape: list = []
26
+ num_video_frames: Union[None, int] = None
27
+ height: Union[None, int] = None
28
+ width: Union[None, int] = None
29
+
30
+
31
+ @attrs.define(slots=False)
32
+ class SamplingConfig:
33
+ """
34
+ Sampling config
35
+ Args:
36
+ temperature (float): Temperature value for controlling randomness in sampling. Defaults to 0.6.
37
+ top_p (float): Top-p probability threshold for nucleus sampling. Defaults to 0.9.
38
+ logprobs (bool): Flag indicating whether to compute token log probabilities. Defaults to False.
39
+ echo (bool): Flag indicating whether to include prompt tokens in the generated output. Defaults to False.
40
+
41
+ """
42
+
43
+ temperature: float = 0.6
44
+ top_k: Optional[int] = None
45
+ top_p: float = 0.9
46
+ compile_prefill: bool = False
47
+ compile_sampling: bool = True
48
+ logprobs: bool = False
49
+ echo: bool = False
50
+
51
+
52
+ @attrs.define(slots=False)
53
+ class DiffusionDecoderSamplingConfig:
54
+ """
55
+ Diffusion decoder sampling config
56
+ Args:
57
+ guidance (float): Guidance scale for the diffusion process. Controls how much the model follows the conditioning. Defaults to 1.8.
58
+ sigma_min (float): Minimum noise level for the diffusion process. Defaults to 0.02.
59
+ sigma (float): Initial noise level for the diffusion process. Defaults to 8.
60
+ num_steps (int): Number of denoising steps to perform. Defaults to 15.
61
+ overlap (int): Number of overlapping frames between video chunks during processing. Defaults to 2.
62
+ continuous_tokenizer_channel (int): Number of channels in the continuous tokenizer of diffusion decoder. Defaults to 16.
63
+ continuous_tokenizer_spatial_compression_ratio (int): Spatial compression ratio for the continuous tokenizer of diffusion decoder. Defaults to 8.
64
+ dd_train_num_video_frames (int): Number of video frames used during training for diffusion decoder. Defaults to 57.
65
+ """
66
+
67
+ guidance: float = 1.8
68
+ sigma_min: float = 0.02
69
+ sigma: float = 8
70
+ num_steps: int = 15
71
+ overlap: int = 2
72
+ continuous_tokenizer_channel: int = 16
73
+ continuous_tokenizer_spatial_compression_ratio: int = 8
74
+ dd_train_num_video_frames: int = 57
75
+ max_iter: int = 99
76
+ fps: int = 24
77
+
78
+
79
+ @attrs.define(slots=False)
80
+ class InferenceConfig:
81
+ """
82
+ Inference config
83
+ Args:
84
+ model_config (ModelConfig): Model config
85
+ tokenizer_config (TokenizerConfig): Tokenizer config
86
+ ckpt_path (str): Path to the checkpoint
87
+ latent_shape (list): Shape of the latent
88
+ """
89
+
90
+ model_config: ModelConfig = None
91
+ tokenizer_config: TokenizerConfig = None
92
+ ckpt_path: str = ""
93
+ data_shape_config: DataShapeConfig = None
94
+
95
+ defaults: List[Any] = attrs.field(
96
+ factory=lambda: [
97
+ "_self_",
98
+ {"data_val": None},
99
+ {"data_shape_config": "video_shape_as_model_config"},
100
+ {"eval_job": None},
101
+ ]
102
+ )
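A minimal usage sketch (assumed, not part of the repository) showing how SamplingConfig can be overridden for more deterministic sampling:

```python
# Minimal sketch (assumed): a more deterministic sampling setup than the defaults above.
sampling_config = SamplingConfig(
    temperature=0.3,  # lower temperature -> less random sampling
    top_p=0.8,
    logprobs=True,    # also return per-token log probabilities
    echo=False,
)
```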
cosmos_predict1/autoregressive/configs/registry.py ADDED
@@ -0,0 +1,89 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import torch
17
+ from hydra.core.config_store import ConfigStore
18
+
19
+ from cosmos_predict1.autoregressive.configs.base.callbacks import BASIC_CALLBACKS, VIDEO_TEACHER_FORCING_CALLBACK
20
+ from cosmos_predict1.autoregressive.configs.base.dataloader import get_tealrobot_video
21
+ from cosmos_predict1.autoregressive.configs.base.optim import LambdaLinearLR
22
+ from cosmos_predict1.autoregressive.configs.experiment.video2video.basic import register_experiments
23
+ from cosmos_predict1.utils import config, log
24
+ from cosmos_predict1.utils.lazy_config import LazyCall as L
25
+ from cosmos_predict1.utils.scheduler import WarmupCosineLR
26
+
27
+
28
+ def register_checkpoint(cs):
29
+ checkpoint_local = config.CheckpointConfig(
30
+ save_iter=5000,
31
+ broadcast_via_filesystem=True,
32
+ )
33
+ cs.store(group="checkpoint", package="checkpoint", name="local", node=checkpoint_local)
34
+
35
+
36
+ def register_callbacks(cs):
37
+ cs.store(group="callbacks", package="trainer.callbacks", name="basic", node=BASIC_CALLBACKS)
38
+ cs.store(
39
+ group="callbacks",
40
+ package="trainer.callbacks",
41
+ name="video_teacher_forcing",
42
+ node=VIDEO_TEACHER_FORCING_CALLBACK,
43
+ )
44
+
45
+
46
+ def register_scheduler(cs):
47
+ cs.store(
48
+ group="scheduler",
49
+ package="scheduler",
50
+ name="warmup_cosine_lr",
51
+ node=L(WarmupCosineLR)(optimizer=None, warmup_iters=5000, lr_decay_iters="${trainer.max_iter}", min_lr=1e-8),
52
+ )
53
+ cs.store(group="scheduler", package="scheduler", name="lambdalinear", node=LambdaLinearLR)
54
+
55
+
56
+ def register_optimizer(cs):
57
+ cs.store(
58
+ group="optimizer",
59
+ package="optimizer",
60
+ name="fused_adamw",
61
+ node=L(torch.optim.AdamW)(params=None, lr=1e-3, weight_decay=0.05, fused=True),
62
+ )
63
+ cs.store(
64
+ group="optimizer",
65
+ package="optimizer",
66
+ name="sgd",
67
+ node=L(torch.optim.SGD)(params=None, lr=5e-6, momentum=0.9),
68
+ )
69
+
70
+
71
+ def register_training_data(cs):
72
+ cs.store(
73
+ group="data_train",
74
+ package="dataloader_train",
75
+ name="tealrobot_video_small",
76
+ node=get_tealrobot_video(num_frames=33, video_size=[384, 640]),
77
+ )
78
+ cs.store(group="data_train", package="dataloader_train", name="tealrobot_video", node=get_tealrobot_video())
79
+
80
+
81
+ def register_configs():
82
+ log.info("Registering configs for autoregressive_base")
83
+ cs = ConfigStore.instance()
84
+ register_callbacks(cs)
85
+ register_checkpoint(cs)
86
+ register_optimizer(cs)
87
+ register_scheduler(cs)
88
+ register_training_data(cs)
89
+ register_experiments(cs)
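Additional choices can be registered with the same Hydra ConfigStore pattern. A minimal sketch (assumed); the optimizer name and hyperparameters are hypothetical:

```python
# Sketch (assumed): registering an extra optimizer choice; "adamw_lowlr" is hypothetical.
def register_extra_optimizer(cs):
    cs.store(
        group="optimizer",
        package="optimizer",
        name="adamw_lowlr",
        node=L(torch.optim.AdamW)(params=None, lr=1e-5, weight_decay=0.01, fused=True),
    )
```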
cosmos_predict1/autoregressive/datasets/dataset_utils.py ADDED
@@ -0,0 +1,173 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Any, Optional
17
+
18
+ import torch
19
+ import torchvision.transforms.functional as transforms_F
20
+ from PIL import Image
21
+
22
+
23
+ def obtain_image_size(data_dict: dict, input_keys: list) -> tuple[int, int]:
24
+ r"""Function for obtaining the image size from the data dict.
25
+
26
+ Args:
27
+ data_dict (dict): Input data dict
28
+ input_keys (list): List of input keys
29
+ Returns:
30
+ width (int): Width of the input image
31
+ height (int): Height of the input image
32
+ """
33
+
34
+ data1 = data_dict[input_keys[0]]
35
+ if isinstance(data1, Image.Image):
36
+ width, height = data1.size
37
+ elif isinstance(data1, torch.Tensor):
38
+ height, width = data1.size()[-2:]
39
+ else:
40
+ raise ValueError("data to random crop should be PIL Image or tensor")
41
+
42
+ return width, height
43
+
44
+
45
+ class Augmentor:
46
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
47
+ r"""Base augmentor class
48
+
49
+ Args:
50
+ input_keys (list): List of input keys
51
+ output_keys (list): List of output keys
52
+ args (dict): Arguments associated with the augmentation
53
+ """
54
+ self.input_keys = input_keys
55
+ self.output_keys = output_keys
56
+ self.args = args
57
+
58
+ def __call__(self, *args: Any, **kwds: Any) -> Any:
59
+ raise ValueError("Augmentor not implemented")
60
+
61
+
62
+ class ResizeSmallestSideAspectPreserving(Augmentor):
63
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
64
+ super().__init__(input_keys, output_keys, args)
65
+
66
+ def __call__(self, data_dict: dict) -> dict:
67
+ r"""Performs aspect-ratio preserving resizing.
68
+ The image is scaled by max(img_w / orig_w, img_h / orig_h) with the aspect ratio preserved,
69
+ so the dimension with the larger (target / original) ratio lands exactly on its target size
70
+ (up to rounding) and the other dimension is at least as large as its target.
71
+
72
+ Args:
73
+ data_dict (dict): Input data dict
74
+ Returns:
75
+ data_dict (dict): Output dict where images are resized
76
+ """
77
+
78
+ if self.output_keys is None:
79
+ self.output_keys = self.input_keys
80
+ assert self.args is not None, "Please specify args in augmentations"
81
+
82
+ img_w, img_h = self.args["img_w"], self.args["img_h"]
83
+
84
+ orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
85
+ scaling_ratio = max((img_w / orig_w), (img_h / orig_h))
86
+ target_size = (int(scaling_ratio * orig_h + 0.5), int(scaling_ratio * orig_w + 0.5))
87
+
88
+ assert (
89
+ target_size[0] >= img_h and target_size[1] >= img_w
90
+ ), f"Resize error. orig {(orig_w, orig_h)} desire {(img_w, img_h)} compute {target_size}"
91
+
92
+ for inp_key, out_key in zip(self.input_keys, self.output_keys):
93
+ data_dict[out_key] = transforms_F.resize(
94
+ data_dict[inp_key],
95
+ size=target_size, # type: ignore
96
+ interpolation=getattr(self.args, "interpolation", transforms_F.InterpolationMode.BICUBIC),
97
+ antialias=True,
98
+ )
99
+
100
+ if out_key != inp_key:
101
+ del data_dict[inp_key]
102
+ return data_dict
103
+
104
+
105
+ class CenterCrop(Augmentor):
106
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
107
+ super().__init__(input_keys, output_keys, args)
108
+
109
+ def __call__(self, data_dict: dict) -> dict:
110
+ r"""Performs center crop.
111
+
112
+ Args:
113
+ data_dict (dict): Input data dict
114
+ Returns:
115
+ data_dict (dict): Output dict where images are center cropped.
116
+ We also save the cropping parameters in the aug_params dict
117
+ so that it will be used by other transforms.
118
+ """
119
+ assert (
120
+ (self.args is not None) and ("img_w" in self.args) and ("img_h" in self.args)
121
+ ), "Please specify size in args"
122
+
123
+ img_w, img_h = self.args["img_w"], self.args["img_h"]
124
+
125
+ orig_w, orig_h = obtain_image_size(data_dict, self.input_keys)
126
+ for key in self.input_keys:
127
+ data_dict[key] = transforms_F.center_crop(data_dict[key], [img_h, img_w])
128
+
129
+ # We also add the aug params we use. This will be useful for other transforms
130
+ crop_x0 = (orig_w - img_w) // 2
131
+ crop_y0 = (orig_h - img_h) // 2
132
+ cropping_params = {
133
+ "resize_w": orig_w,
134
+ "resize_h": orig_h,
135
+ "crop_x0": crop_x0,
136
+ "crop_y0": crop_y0,
137
+ "crop_w": img_w,
138
+ "crop_h": img_h,
139
+ }
140
+
141
+ if "aug_params" not in data_dict:
142
+ data_dict["aug_params"] = dict()
143
+
144
+ data_dict["aug_params"]["cropping"] = cropping_params
145
+ data_dict["padding_mask"] = torch.zeros((1, cropping_params["crop_h"], cropping_params["crop_w"]))
146
+ return data_dict
147
+
148
+
149
+ class Normalize(Augmentor):
150
+ def __init__(self, input_keys: list, output_keys: Optional[list] = None, args: Optional[dict] = None) -> None:
151
+ super().__init__(input_keys, output_keys, args)
152
+
153
+ def __call__(self, data_dict: dict) -> dict:
154
+ r"""Performs data normalization.
155
+
156
+ Args:
157
+ data_dict (dict): Input data dict
158
+ Returns:
159
+ data_dict (dict): Output dict where images are center cropped.
160
+ """
161
+ assert self.args is not None, "Please specify args"
162
+
163
+ mean = self.args["mean"]
164
+ std = self.args["std"]
165
+
166
+ for key in self.input_keys:
167
+ if isinstance(data_dict[key], torch.Tensor):
168
+ data_dict[key] = data_dict[key].to(dtype=torch.get_default_dtype()).div(255)
169
+ else:
170
+ data_dict[key] = transforms_F.to_tensor(data_dict[key]) # division by 255 is applied in to_tensor()
171
+
172
+ data_dict[key] = transforms_F.normalize(tensor=data_dict[key], mean=mean, std=std)
173
+ return data_dict
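A minimal sketch (assumed, not part of the repository) chaining the three augmentors above on a dummy uint8 video tensor of shape [T, C, H, W], the same composition VideoDataset applies in the next file:

```python
# Minimal sketch (assumed): resize -> center crop -> normalize on a dummy video tensor.
import torch

data = {"video": torch.randint(0, 256, (8, 3, 480, 854), dtype=torch.uint8)}
resize = ResizeSmallestSideAspectPreserving(input_keys=["video"], args={"img_w": 640, "img_h": 384})
crop = CenterCrop(input_keys=["video"], args={"img_w": 640, "img_h": 384})
normalize = Normalize(input_keys=["video"], args={"mean": 0.5, "std": 0.5})

data = normalize(crop(resize(data)))
print(data["video"].shape, data["video"].dtype)  # torch.Size([8, 3, 384, 640]) torch.float32
```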
cosmos_predict1/autoregressive/datasets/video_dataset.py ADDED
@@ -0,0 +1,190 @@
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ Run this command to interactively debug:
18
+ PYTHONPATH=. python cosmos_predict1/autoregressive/datasets/video_dataset.py
19
+ """
20
+
21
+ import os
22
+ import traceback
23
+ import warnings
24
+ from concurrent.futures import ThreadPoolExecutor, as_completed
25
+
26
+ import numpy as np
27
+ import torch
28
+ from decord import VideoReader, cpu
29
+ from torch.utils.data import Dataset
30
+ from tqdm import tqdm
31
+
32
+ from cosmos_predict1.autoregressive.configs.base.dataset import VideoDatasetConfig
33
+ from cosmos_predict1.autoregressive.datasets.dataset_utils import (
34
+ CenterCrop,
35
+ Normalize,
36
+ ResizeSmallestSideAspectPreserving,
37
+ )
38
+
39
+
+ class VideoDataset(Dataset):
+     def __init__(self, config: VideoDatasetConfig):
+         """Video Dataset class for loading video-to-video generation data."""
+
+         super().__init__()
+         self.dataset_dir = config.dataset_dir
+         self.sequence_interval = config.sequence_interval
+         self.sequence_length = config.num_frames
+         self.video_size = config.video_size
+         self.start_frame_interval = config.start_frame_interval
+
+         self.video_dir = self.dataset_dir
+         self.video_paths = [os.path.join(self.video_dir, f) for f in os.listdir(self.video_dir) if f.endswith(".mp4")]
+         print(f"{len(self.video_paths)} videos in total")
+
+         self.samples = self._init_samples(self.video_paths)
+         self.samples = sorted(self.samples, key=lambda x: (x["video_path"], x["frame_ids"][0]))
+         print(f"{len(self.samples)} samples in total")
+         self.wrong_number = 0
+
+         self.resize_transform = ResizeSmallestSideAspectPreserving(
+             input_keys=["video"],
+             args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
+         )
+         self.crop_transform = CenterCrop(
+             input_keys=["video"],
+             args={"img_w": self.video_size[1], "img_h": self.video_size[0]},
+         )
+         self.normalize_transform = Normalize(
+             input_keys=["video"],
+             args={"mean": 0.5, "std": 0.5},
+         )
+
+     def __str__(self):
+         return f"{len(self.video_paths)} samples from {self.dataset_dir}"
+
+     def _init_samples(self, video_paths):
+         samples = []
+         with ThreadPoolExecutor(32) as executor:
+             future_to_video_path = {
+                 executor.submit(self._load_and_process_video_path, video_path): video_path for video_path in video_paths
+             }
+             for future in tqdm(as_completed(future_to_video_path), total=len(video_paths)):
+                 samples.extend(future.result())
+         return samples
+
+     def _load_and_process_video_path(self, video_path):
+         vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
+         n_frames = len(vr)
+
+         samples = []
+         for frame_i in range(0, n_frames, self.start_frame_interval):
+             sample = dict()
+             sample["video_path"] = video_path
+             sample["orig_num_frames"] = n_frames
+             sample["chunk_index"] = -1
+             sample["frame_ids"] = []
+             curr_frame_i = frame_i
+             while True:
+                 if curr_frame_i > (n_frames - 1):
+                     break
+                 sample["frame_ids"].append(curr_frame_i)
+                 if len(sample["frame_ids"]) == self.sequence_length:
+                     break
+                 curr_frame_i += self.sequence_interval
+             # make sure there are sequence_length number of frames
+             if len(sample["frame_ids"]) == self.sequence_length:
+                 sample["chunk_index"] += 1
+                 samples.append(sample)
+         return samples
+
+     def __len__(self):
+         return len(self.samples)
+
+     def _load_video(self, video_path, frame_ids):
+         vr = VideoReader(video_path, ctx=cpu(0), num_threads=2)
+         assert (np.array(frame_ids) < len(vr)).all(), "Some frame_ids are out of range."
+         assert (np.array(frame_ids) >= 0).all(), "Some frame_ids are negative."
+         vr.seek(0)
+         frame_data = vr.get_batch(frame_ids).asnumpy()
+         fps = vr.get_avg_fps()
+         return frame_data, fps
+
+     def _get_frames(self, video_path, frame_ids):
+         frames, fps = self._load_video(video_path, frame_ids)
+         frames = frames.astype(np.uint8)
+         frames = torch.from_numpy(frames)
+         frames = frames.permute(0, 3, 1, 2)  # Rearrange from [T, H, W, C] to [T, C, H, W]
+         return frames, fps
+
+     def __getitem__(self, index):
+         try:
+             sample = self.samples[index]
+             video_path = sample["video_path"]
+             frame_ids = sample["frame_ids"]
+
+             data = dict()
+
+             video, fps = self._get_frames(video_path, frame_ids)
+             data["video"] = video
+             data["fps"] = fps
+             data["num_frames"] = self.sequence_length
+             data["orig_num_frames"] = sample["orig_num_frames"]
+             data["chunk_index"] = sample["chunk_index"]
+             data["frame_start"] = frame_ids[0]
+             data["frame_end"] = frame_ids[-1]
+
+             data["video_name"] = {
+                 "video_path": video_path,
+                 "start_frame_id": str(frame_ids[0]),
+             }
+
+             # resize video to smallest side aspect preserving
+             data = self.resize_transform(data)
+             # center crop video
+             data = self.crop_transform(data)
+             # normalize video
+             data = self.normalize_transform(data)
+
+             data["video"] = data["video"].permute(1, 0, 2, 3)  # Rearrange from [T, C, H, W] to [C, T, H, W]
+
+             return data
+         except Exception:
+             warnings.warn(
+                 f"Invalid data encountered: {self.samples[index]['video_path']}. Skipped "
+                 f"(by randomly sampling another sample in the same dataset)."
+             )
+             warnings.warn("FULL TRACEBACK:")
+             warnings.warn(traceback.format_exc())
+             self.wrong_number += 1
+             print(self.wrong_number)
+             return self[np.random.randint(len(self.samples))]
+
+
+ if __name__ == "__main__":
+     config = VideoDatasetConfig(dataset_dir="datasets/cosmos_nemo_assets/videos/")
+     dataset = VideoDataset(config)
+
+     indices = [0, 1, 2, -1]
+     for idx in indices:
+         data = dataset[idx]
+         print(
+             (
+                 f"{idx=} "
+                 f"{data['video'].sum()=}\n"
+                 f"{data['video'].shape=}\n"
+                 f"{data['video_name']=}\n"
+                 f"{data.keys()=}\n"
+                 "---"
+             )
+         )
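
The sampling logic in `_load_and_process_video_path` slides a window over each video: a candidate chunk starts every `start_frame_interval` frames, frames inside a chunk are `sequence_interval` apart, and a chunk is kept only if it collects exactly `num_frames` indices; each returned sample then carries the clip as a normalized float tensor in `[C, T, H, W]` order. A standalone sketch of that index arithmetic (no video decoding; the parameter values are illustrative, not the config defaults):

```python
def chunk_frame_ids(n_frames: int, num_frames: int, sequence_interval: int, start_frame_interval: int):
    """Reproduce the frame-id chunking used by VideoDataset._load_and_process_video_path."""
    chunks = []
    for start in range(0, n_frames, start_frame_interval):
        frame_ids = list(range(start, n_frames, sequence_interval))[:num_frames]
        if len(frame_ids) == num_frames:  # drop incomplete chunks at the end of the video
            chunks.append(frame_ids)
    return chunks

# A 20-frame video, 4 frames per sample, every 2nd frame, a new chunk every 8 frames.
print(chunk_frame_ids(n_frames=20, num_frames=4, sequence_interval=2, start_frame_interval=8))
# [[0, 2, 4, 6], [8, 10, 12, 14]] -- the window starting at frame 16 is dropped as incomplete.
```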