tchoudha21 committed on
Commit
5a33b3a
·
verified ·
1 Parent(s): 0660fbf

Upload modified files

Browse files
base_world_generation_pipeline.py ADDED
@@ -0,0 +1,362 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import os
18
+ from abc import ABC
19
+ from typing import Any
20
+
21
+ import numpy as np
22
+ import torch
23
+
24
+ from Cosmos.t5_text_encoder import CosmosT5TextEncoder
25
+ from Cosmos import guardrail_presets as guardrail_presets
26
+
27
+
28
+ class BaseWorldGenerationPipeline(ABC):
29
+ def __init__(
30
+ self,
31
+ inference_type: str | None = None,
32
+ checkpoint_dir: str | None = None,
33
+ checkpoint_name: str | None = None,
34
+ enable_text_guardrail: bool = False,
35
+ enable_video_guardrail: bool = False,
36
+ offload_network: bool = False,
37
+ offload_tokenizer: bool = False,
38
+ offload_text_encoder_model: bool = False,
39
+ offload_guardrail_models: bool = False,
40
+ ):
41
+ """Initialize base world generation pipeline.
42
+
43
+ This abstract base class provides core functionality for world generation models including:
44
+ - Model loading and initialization
45
+ - Text encoding and embedding
46
+ - Safety checks and content filtering
47
+ - Memory management through model offloading
48
+
49
+ Args:
50
+ inference_type: The type of inference pipeline ("text2world" or "video2world")
51
+ checkpoint_dir: Root directory containing model checkpoints
52
+ checkpoint_name: Name of the specific checkpoint file to load
53
+ enable_text_guardrail: If True, validates input prompts for safety
54
+ enable_video_guardrail: If True, validates generated videos for safety
55
+ offload_network: If True, moves main model to CPU after inference
56
+ offload_tokenizer: If True, moves tokenizer to CPU after use
57
+ offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding
58
+ offload_guardrail_models: If True, moves safety models to CPU after checks
59
+ """
60
+ self.inference_type = inference_type
61
+ self.checkpoint_dir = checkpoint_dir
62
+ self.checkpoint_name = checkpoint_name
63
+ self.guardrail_dir = "Cosmos-1.0-Guardrail"
64
+ self.enable_text_guardrail = enable_text_guardrail
65
+ self.enable_video_guardrail = enable_video_guardrail
66
+
67
+ # Add offloading flags
68
+ self.offload_network = offload_network
69
+ self.offload_tokenizer = offload_tokenizer
70
+ self.offload_text_encoder_model = offload_text_encoder_model
71
+ self.offload_guardrail_models = offload_guardrail_models
72
+
73
+ # Initialize model instances
74
+ self.text_guardrail = None
75
+ self.video_guardrail = None
76
+ self.text_encoder = None
77
+ self.model = None
78
+
79
+ self._load_model()
80
+
81
+ if not self.offload_text_encoder_model:
82
+ self._load_text_encoder_model()
83
+ if not self.offload_guardrail_models:
84
+ if self.enable_text_guardrail:
85
+ self._load_text_guardrail()
86
+ if self.enable_video_guardrail:
87
+ self._load_video_guardrail()
88
+ if not self.offload_network:
89
+ self._load_network()
90
+ if not self.offload_tokenizer:
91
+ self._load_tokenizer()
92
+
93
+ def _load_tokenizer(self):
94
+ pass
95
+
96
+ def _load_network(self):
97
+ pass
98
+
99
+ def _load_model(self, checkpoint_name: str) -> Any:
100
+ """Load the world generation model from a checkpoint.
101
+
102
+ This abstract method must be implemented by subclasses to load their specific
103
+ model architecture and weights.
104
+
105
+ Args:
106
+ checkpoint_name: Path to the model checkpoint file
107
+
108
+ Returns:
109
+ The loaded model instance
110
+
111
+ Raises:
112
+ NotImplementedError: Must be implemented by subclasses
113
+ """
114
+ pass
115
+
116
+ def _load_text_encoder_model(self):
117
+ """Load the T5 text encoder model.
118
+
119
+ Initializes and loads the T5 encoder model used for converting text prompts
120
+ into embeddings that condition the world generation model.
121
+
122
+ Returns:
123
+ Loaded T5 text encoder model instance
124
+ """
125
+ self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)
126
+
127
+ def _load_text_guardrail(self):
128
+ """Load text safety classifier models.
129
+
130
+ Initializes models used for checking input prompts against safety policies.
131
+ Models are loaded from the specified guardrail directory.
132
+ """
133
+ self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
134
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
135
+ )
136
+
137
+ def _load_video_guardrail(self):
138
+ """Load video safety classifier models.
139
+
140
+ Initializes models used for validating generated video content against
141
+ safety policies. Models are loaded from the specified guardrail directory.
142
+ """
143
+ self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
144
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
145
+ )
146
+
147
+ def _offload_network(self):
148
+ if self.model.model:
149
+ del self.model.model
150
+ self.model.model = None
151
+ gc.collect()
152
+ torch.cuda.empty_cache()
153
+
154
+ def _offload_tokenizer(self):
155
+ if self.model.tokenizer:
156
+ del self.model.tokenizer
157
+ self.model.tokenizer = None
158
+ gc.collect()
159
+ torch.cuda.empty_cache()
160
+
161
+ def _offload_guardrail_models(self):
162
+ """Offload safety classifier models to reduce memory usage.
163
+
164
+ Moves safety models to CPU and clears GPU memory if they are no longer needed.
165
+ This helps manage memory when processing multiple inputs sequentially.
166
+ """
167
+ if self.text_guardrail:
168
+ del self.text_guardrail
169
+ self.text_guardrail = None
170
+ if self.video_guardrail:
171
+ del self.video_guardrail
172
+ self.video_guardrail = None
173
+ gc.collect()
174
+ torch.cuda.empty_cache()
175
+
176
+ def _offload_text_encoder_model(self):
177
+ """Offload T5 text encoder to reduce memory usage.
178
+
179
+ Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete.
180
+ This helps manage memory when processing multiple inputs sequentially.
181
+ """
182
+ if self.text_encoder:
183
+ del self.text_encoder
184
+ self.text_encoder = None
185
+ gc.collect()
186
+ torch.cuda.empty_cache()
187
+
188
+ def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
189
+ """Generate world latents using the model.
190
+
191
+ This abstract method must be implemented by subclasses to define their specific
192
+ generation process.
193
+
194
+ Args:
195
+ *args: Variable positional arguments for model inference
196
+ **kwargs: Variable keyword arguments for model inference
197
+
198
+ Returns:
199
+ torch.Tensor: Generated world representation tensor
200
+ """
201
+ pass
202
+
203
+ def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
204
+ """Generate world representation with memory management.
205
+
206
+ Handles loading the model before inference and offloading afterward if enabled.
207
+ This helps minimize GPU memory usage during inference.
208
+
209
+ Args:
210
+ *args: Arguments passed to _run_model
211
+ **kwargs: Keyword arguments passed to _run_model
212
+
213
+ Returns:
214
+ np.ndarray: Generated world representation as numpy array
215
+ """
216
+ pass
217
+
218
+ def _run_guardrail_on_prompt(self, prompt: str) -> bool:
219
+ """Check if prompt meets safety requirements.
220
+
221
+ Validates the input prompt against safety policies using loaded guardrail models.
222
+
223
+ Args:
224
+ prompt: Raw text prompt to validate
225
+
226
+ Returns:
227
+ bool: True if prompt passes all safety checks, False otherwise
228
+ """
229
+ return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)
230
+
231
+ def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
232
+ """Check prompt safety with memory management.
233
+
234
+ Validates prompt safety while handling model loading/offloading to manage memory.
235
+
236
+ Args:
237
+ prompt: Raw text prompt to validate
238
+
239
+ Returns:
240
+ bool: True if prompt passes all safety checks, False otherwise
241
+ """
242
+ if self.offload_guardrail_models:
243
+ self._load_text_guardrail()
244
+
245
+ is_safe = self._run_guardrail_on_prompt(prompt)
246
+
247
+ if self.offload_guardrail_models:
248
+ self._offload_guardrail_models()
249
+
250
+ return is_safe
251
+
252
+ def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
253
+ """Check if video meets safety requirements.
254
+
255
+ Validates generated video content against safety policies using guardrail models.
256
+
257
+ Args:
258
+ video: Video frames to validate
259
+
260
+ Returns:
261
+ np.ndarray: Processed video if safe, None if unsafe
262
+ """
263
+ return guardrail_presets.run_video_guardrail(video, self.video_guardrail)
264
+
265
+ def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
266
+ """Check if generated video meets safety requirements.
267
+
268
+ Args:
269
+ video: Video frames to validate
270
+
271
+ Returns:
272
+ np.ndarray: Processed video frames if safe, None otherwise
273
+
274
+ Note:
275
+ Guardrail models are offloaded after checks if enabled.
276
+ """
277
+ if self.offload_guardrail_models:
278
+ self._load_video_guardrail()
279
+
280
+ video = self._run_guardrail_on_video(video)
281
+
282
+ if self.offload_guardrail_models:
283
+ self._offload_guardrail_models()
284
+ return video
285
+
286
+ def _run_text_embedding_on_prompt(
287
+ self, prompts: list[str], **kwargs: Any
288
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
289
+ """Convert text prompts to embeddings.
290
+
291
+ Processes text prompts into embedding tensors that condition the generation model.
292
+
293
+ Args:
294
+ prompts: List of text prompts to encode
295
+ **kwargs: Additional arguments for text encoding
296
+
297
+ Returns:
298
+ tuple containing:
299
+ - List of text embedding tensors for each prompt
300
+ - List of attention masks for each embedding
301
+ """
302
+
303
+ embeddings = []
304
+ masks = []
305
+ for prompt in prompts:
306
+ embedding, mask = self.text_encoder.encode_prompts(
307
+ [prompt],
308
+ **kwargs,
309
+ )
310
+ embeddings.append(embedding)
311
+ masks.append(mask)
312
+
313
+ return embeddings, masks
314
+
315
+ def _run_text_embedding_on_prompt_with_offload(
316
+ self, prompts: list[str], **kwargs: Any
317
+ ) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
318
+ """Convert text prompt into embeddings using T5 encoder.
319
+
320
+ Args:
321
+ prompt: Processed and validated text prompt
322
+
323
+ Returns:
324
+ Text embedding tensor to condition diffusion model
325
+
326
+ Note:
327
+ T5 model is offloaded after encoding if enabled.
328
+ """
329
+ if self.offload_text_encoder_model:
330
+ self._load_text_encoder_model()
331
+
332
+ embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)
333
+
334
+ if self.offload_text_encoder_model:
335
+ self._offload_text_encoder_model()
336
+ return embeddings, masks
337
+
338
+ def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
339
+ """Decode model outputs into final world representation.
340
+
341
+ This abstract method must be implemented by subclasses to convert raw model
342
+ outputs into their specific world representation format.
343
+
344
+ Args:
345
+ samples: Raw output tensor from the generation model
346
+
347
+ Returns:
348
+ np.ndarray: Decoded world representation
349
+ """
350
+ pass
351
+
352
+ def generate(self, *args: Any, **kwargs: Any):
353
+ """Generate world representation.
354
+
355
+ This abstract method must be implemented by subclasses to convert raw model
356
+ outputs into their specific world representation format.
357
+
358
+ Args:
359
+ *args: Variable positional arguments for model inference
360
+ **kwargs: Variable keyword arguments for model inference
361
+ """
362
+ pass
conditioner.py ADDED
@@ -0,0 +1,323 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import copy
17
+ from abc import ABC, abstractmethod
18
+ from collections import defaultdict
19
+ from dataclasses import dataclass, fields
20
+ from enum import Enum
21
+ from typing import Any, Dict, List, Optional, Tuple, Union
22
+
23
+ import torch
24
+ import torch.nn as nn
25
+
26
+ from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
27
+ from Cosmos.utils import log
28
+ from Cosmos.lazy_config import instantiate
29
+
30
+
31
class BaseConditionEntry(nn.Module):
    """Base class for a single conditioning input of the diffusion model.

    Holds per-entry configuration (dropout rate, batch input key, whether the
    entry returns a dict) and provides batch-wise random dropout of its input.
    """

    def __init__(self):
        super().__init__()
        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        """Probability of dropping this entry's input per batch element."""
        return self._dropout_rate

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    @property
    def input_key(self) -> str:
        """Key under which this entry's input is found in the data batch."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key

    @property
    def is_return_dict(self) -> bool:
        """Whether forward() returns a dict rather than a bare tensor."""
        return self._return_dict

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Zero out whole batch elements of `in_tensor` with the configured probability.

        Each element of the batch is independently kept with probability
        (1 - dropout_rate); dropped elements are multiplied by zero.
        """
        del key  # unused here; subclasses may dispatch on it
        rate = self.dropout_rate if dropout_rate is None else dropout_rate
        keep_mask = torch.bernoulli((1.0 - rate) * torch.ones(in_tensor.shape[0])).type_as(in_tensor)
        return batch_mul(keep_mask, in_tensor)

    def summary(self) -> str:
        """Human-readable description of the entry; subclasses may override."""
        pass
87
+
88
+
89
class DataType(Enum):
    """Modality of the conditioned data: still image or video."""

    IMAGE = "image"
    VIDEO = "video"
92
+
93
+
94
class TextAttr(BaseConditionEntry):
    """Text-conditioning entry that forwards T5 embeddings and masks unchanged."""

    def __init__(self):
        super().__init__()

    def forward(self, token: torch.Tensor, mask: torch.Tensor):
        # Expose the inputs under the keys the network's cross-attention expects.
        return {"crossattn_emb": token, "crossattn_mask": mask}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Apply the base dropout policy, except attention masks are never dropped."""
        is_mask = key is not None and "mask" in key
        if is_mask:
            return in_tensor
        return super().random_dropout_input(in_tensor, dropout_rate, key)
107
+
108
+
109
@dataclass
class BaseVideoCondition:
    """Conditioning inputs consumed by the video diffusion network.

    Required cross-attention embeddings/masks plus optional per-clip metadata
    (fps, frame count, image size, padding mask, scalar feature).
    """

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return every dataclass field as a name -> value mapping."""
        return {field.name: getattr(self, field.name) for field in fields(self)}
122
+
123
+
124
@dataclass
class VideoExtendCondition(BaseVideoCondition):
    """Conditioning for video-extension (video2world) models.

    Extends BaseVideoCondition with the ground-truth latent of the conditioning
    frames and masks describing which region of the clip is conditioned.
    """

    video_cond_bool: Optional[torch.Tensor] = None  # whether or not it conditioned on video
    gt_latent: Optional[torch.Tensor] = None
    condition_video_indicator: Optional[torch.Tensor] = None  # 1 for condition region

    # condition_video_input_mask will concat to the input of network, along channel dim;
    # Will be concat with the input tensor
    condition_video_input_mask: Optional[torch.Tensor] = None
    # condition_video_augment_sigma: (B, T) tensor of sigma value for the conditional input augmentation, only valid when apply_corruption_to_condition_region is "noise_with_sigma" or "noise_with_sigma_fixed"
    condition_video_augment_sigma: Optional[torch.Tensor] = None
135
+
136
+
137
class GeneralConditioner(nn.Module, ABC):
    """
    Abstract container that manages a set of embedding models (embedders) with
    conditional and unconditional configurations. Each embedder can have its
    dropout rate adjusted dynamically to switch between conditioned and
    unconditioned behavior.

    Attributes:
        KEY2DIM (dict): Maps output keys to the dimension along which outputs
            sharing that key are concatenated.
        embedders (nn.ModuleDict): All embedders, instantiated from the
            provided configurations.

    Parameters:
        emb_models (Union[List, Any]): Keyword mapping of embedder name to its
            instantiation configuration.
    """

    KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}

    def __init__(self, **emb_models: Union[List, Any]):
        super().__init__()
        self.embedders = nn.ModuleDict()
        for idx, (emb_name, embconfig) in enumerate(emb_models.items()):
            embedder = instantiate(embconfig.obj)
            assert isinstance(
                embedder, BaseConditionEntry
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)

            # Every embedder must declare where in the batch its input lives.
            if hasattr(embconfig, "input_key"):
                embedder.input_key = embconfig.input_key
            elif hasattr(embconfig, "input_keys"):
                embedder.input_keys = embconfig.input_keys
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            log.debug(f"Initialized embedder #{idx}-{emb_name}: \n {embedder.summary()}")
            self.embedders[emb_name] = embedder

    @abstractmethod
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Any:
        """Should be implemented in subclasses to handle conditon datatype"""
        raise NotImplementedError

    def _forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Dict:
        """
        Run every embedder over `batch` (optionally overriding dropout rates)
        and concatenate the outputs per key along the dimensions in KEY2DIM.

        Parameters:
            batch (Dict): The input data batch to process.
            override_dropout_rate (Optional[Dict[str, float]]): Per-embedder
                dropout-rate overrides, keyed by embedder name.

        Returns:
            Dict: Output tensors concatenated by the configured dimensions.

        Note:
            If the network is sensitive to concatenation order, control the
            order via the config file or give each embedder output a unique key.
        """
        if override_dropout_rate is None:
            override_dropout_rate = {}
        # Reject overrides that name an unknown embedder early.
        for emb_name in override_dropout_rate:
            assert emb_name in self.embedders, f"invalid name found {emb_name}"

        collected = defaultdict(list)
        for emb_name, embedder in self.embedders.items():
            with torch.no_grad():
                rate = override_dropout_rate.get(emb_name, None)
                if getattr(embedder, "input_key", None) is not None:
                    emb_out = embedder(embedder.random_dropout_input(batch[embedder.input_key], rate))
                elif hasattr(embedder, "input_keys"):
                    dropped_inputs = [embedder.random_dropout_input(batch[k], rate, k) for k in embedder.input_keys]
                    emb_out = embedder(*dropped_inputs)
            for out_key, out_val in emb_out.items():
                collected[out_key].append(out_val)
        # Concatenate each key's outputs along its configured dimension (default: last).
        return {key: torch.cat(vals, dim=self.KEY2DIM.get(key, -1)) for key, vals in collected.items()}

    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Produce conditioned and unconditioned outputs for classifier-free guidance.

        Conditioned pass: all dropout rates forced to 0 so every embedder fully
        applies. Unconditioned pass: rates forced to 1 for embedders with a
        meaningful dropout rate (> 1e-4), 0 otherwise.

        Parameters:
            data_batch (Dict): Input batch with all information the embedders need.

        Returns:
            Tuple[Any, Any]: (conditioned output, unconditioned output).
        """
        cond_rates, uncond_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_rates[emb_name] = 0.0
            uncond_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        condition: Any = self(data_batch, override_dropout_rate=cond_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=uncond_rates)
        return condition, un_condition

    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Same as get_condition_uncondition, except the unconditioned pass keeps
        text conditioning active and substitutes negative-prompt embeddings.
        """
        cond_rates, uncond_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_rates[emb_name] = 0.0
            if isinstance(embedder, TextAttr):
                # Text is replaced by the negative prompt rather than dropped.
                uncond_rates[emb_name] = 0.0
            else:
                uncond_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        neg_prompt_batch = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in neg_prompt_batch:
            if isinstance(neg_prompt_batch["neg_t5_text_embeddings"], torch.Tensor):
                neg_prompt_batch["t5_text_embeddings"] = neg_prompt_batch["neg_t5_text_embeddings"]
                neg_prompt_batch["t5_text_mask"] = neg_prompt_batch["neg_t5_text_mask"]

        condition: Any = self(data_batch, override_dropout_rate=cond_rates)
        un_condition: Any = self(neg_prompt_batch, override_dropout_rate=uncond_rates)

        return condition, un_condition
293
+
294
+
295
@dataclass
class CosmosCondition:
    """Minimal conditioning bundle: cross-attention embeddings/masks plus
    optional padding mask and scalar feature."""

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    padding_mask: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return every dataclass field as a name -> value mapping."""
        return {field.name: getattr(self, field.name) for field in fields(self)}
304
+
305
+
306
class VideoConditioner(GeneralConditioner):
    """Conditioner whose embedder outputs form a BaseVideoCondition."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        cond_kwargs = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**cond_kwargs)
314
+
315
+
316
class VideoExtendConditioner(GeneralConditioner):
    """Conditioner whose embedder outputs form a VideoExtendCondition."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        cond_kwargs = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**cond_kwargs)
config.json ADDED
@@ -0,0 +1,10 @@
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "architectures": [
3
+ "DiffusionVideo2World"
4
+ ],
5
+ "auto_map": {
6
+ "AutoConfig": "video2world_hf.DiffusionVideo2WorldConfig",
7
+ "AutoModel": "video2world_hf.DiffusionVideo2World"
8
+ },
9
+ "model_type": "AutoModel"
10
+ }
convert_pixtral_ckpt.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading.
17
+
18
+ Usage:
19
+
20
+ PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py
21
+ """
22
+
23
+ import argparse
24
+ import json
25
+ import os
26
+ import shutil
27
+ from glob import glob
28
+
29
+ import torch
30
+ from huggingface_hub import snapshot_download
31
+ from safetensors.torch import load_file
32
+
33
+
34
+ def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str):
35
+ """
36
+ Main function to convert Pixtral vision model weights to checkpoint and optionally verify and save the converted checkpoint.
37
+
38
+ Args:
39
+ checkpoint_dir (str): Path to the checkpoint directory
40
+ checkpoint_name (str): Name of the checkpoint
41
+ vit_type (str): Type of ViT used in the Pixtral model
42
+
43
+ This function performs the following steps:
44
+ 0. Download the checkpoint from Hugging Face
45
+ 1. Loads the original Pixtral checkpoint
46
+ 2. Splits the checkpoint into vision encoder, projector, and LLM weights
47
+ 3. Reorganizes the weights to match the expected format
48
+ 4. Extracts and verifies the vision encoder configuration
49
+ 5. Optionally verifies the converted checkpoint by loading it into a VisionTransformer
50
+ 6. Optionally saves the converted checkpoint and configuration
51
+ """
52
+
53
+ save_dir = os.path.join(checkpoint_dir, checkpoint_name)
54
+ os.makedirs(save_dir, exist_ok=True)
55
+ # Save the converted checkpoint
56
+ save_path = os.path.join(save_dir, "model.pt")
57
+ if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
58
+ print(f"Checkpoint {save_path} already exists and is not empty")
59
+ return
60
+
61
+ pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
62
+ os.makedirs(pixtral_ckpt_dir, exist_ok=True)
63
+ repo_id = "mistralai/Pixtral-12B-2409"
64
+ print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
65
+ snapshot_download(
66
+ repo_id=repo_id,
67
+ allow_patterns=["params.json", "consolidated.safetensors"],
68
+ local_dir=pixtral_ckpt_dir,
69
+ local_dir_use_symlinks=False,
70
+ )
71
+ orig_dtype = torch.get_default_dtype()
72
+ dtype = torch.bfloat16
73
+ torch.set_default_dtype(dtype)
74
+
75
+ # Load checkpoint file
76
+ ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
77
+ assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
78
+ ckpt_path = ckpt_files[0]
79
+ ckpt = load_file(ckpt_path)
80
+
81
+ # Split checkpoint into weights of vision encoder, projector, and LLM
82
+ vit_key_prefix = "vision_encoder."
83
+ vit_ckpt = {}
84
+ for key, value in ckpt.items():
85
+ if key.startswith(vit_key_prefix):
86
+ vit_ckpt[key.lstrip(vit_key_prefix)] = value
87
+
88
+ projector_key_prefix = "vision_language_adapter."
89
+ projector_ckpt = {}
90
+ substring_replacement_map = {
91
+ "w_in.": "projector.0.",
92
+ "w_out.": "projector.2.",
93
+ }
94
+ for key, value in ckpt.items():
95
+ if key.startswith(projector_key_prefix):
96
+ key = key.lstrip(projector_key_prefix)
97
+ for old, new in substring_replacement_map.items():
98
+ key = key.replace(old, new)
99
+ projector_ckpt[key] = value
100
+
101
+ llm_ckpt = {}
102
+ for key, value in ckpt.items():
103
+ if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix):
104
+ continue
105
+ llm_ckpt[key] = value
106
+
107
+ vlm_ckpt = {}
108
+ for key, value in llm_ckpt.items():
109
+ vlm_ckpt["model." + key] = value
110
+ for key, value in projector_ckpt.items():
111
+ vlm_ckpt["mm_projector." + key] = value
112
+ for key, value in vit_ckpt.items():
113
+ vlm_ckpt["vision_encoder." + key] = value
114
+
115
+ # Load config
116
+ config_path = os.path.join(pixtral_ckpt_dir, "params.json")
117
+ with open(config_path, "r") as f:
118
+ pixtral_config = json.load(f)
119
+
120
+ # Extract the vision encoder configuration
121
+ vision_encoder_config = {
122
+ "dim": pixtral_config["vision_encoder"]["hidden_size"],
123
+ "num_channels": pixtral_config["vision_encoder"]["num_channels"],
124
+ "image_size": pixtral_config["vision_encoder"]["image_size"],
125
+ "patch_size": pixtral_config["vision_encoder"]["patch_size"],
126
+ "rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
127
+ "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
128
+ "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
129
+ "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
130
+ "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
131
+ "norm_type": "rmsnorm",
132
+ "norm_eps": pixtral_config["norm_eps"],
133
+ "image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
134
+ }
135
+ # Configuration for the 400M ViT of Pixtral 12B VLM
136
+ vit_config = dict(
137
+ dim=1024,
138
+ num_channels=3,
139
+ image_size=1024,
140
+ patch_size=16,
141
+ rope_theta=10000,
142
+ ffn_hidden_size=4096,
143
+ n_layers=24,
144
+ n_heads=16,
145
+ n_kv_heads=16,
146
+ norm_type="rmsnorm",
147
+ norm_eps=1e-5,
148
+ image_token_id=10,
149
+ )
150
+ # Compare the two configurations
151
+ for key, value in vit_config.items():
152
+ assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"
153
+
154
+ llm_config_keys = [
155
+ "dim",
156
+ "n_layers",
157
+ "head_dim",
158
+ "hidden_dim",
159
+ "n_heads",
160
+ "n_kv_heads",
161
+ "rope_theta",
162
+ "norm_eps",
163
+ "vocab_size",
164
+ ]
165
+ assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
166
+ replace_map = {
167
+ "hidden_dim": "ffn_hidden_size",
168
+ }
169
+ llm_config = {}
170
+ for k, v in pixtral_config.items():
171
+ if k in llm_config_keys:
172
+ llm_config[replace_map.get(k, k)] = v
173
+ elif k == "vision_encoder":
174
+ llm_config["vision_encoder"] = vit_type
175
+ else:
176
+ raise ValueError(f"Unknown key: {k}")
177
+
178
+ ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
179
+ torch.save(ckpt_to_save, save_path)
180
+ print(f"Model saved to {save_path}")
181
+
182
+ # Save config
183
+ config_path = os.path.join(save_dir, "config.json")
184
+ with open(config_path, "w") as f:
185
+ json.dump(llm_config, f)
186
+
187
+ torch.set_default_dtype(orig_dtype) # Reset the default dtype
188
+
189
+ # Remove the original Pixtral checkpoint
190
+ shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
191
+ print(f"Removed {pixtral_ckpt_dir}")
192
+
193
+
194
if __name__ == "__main__":
    # CLI entry point: download the Pixtral checkpoint and convert it in place.
    parser = argparse.ArgumentParser(
        description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
    )
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory")
    parser.add_argument(
        "--checkpoint_name",
        type=str,
        default="Pixtral-12B",
        help="Name of the checkpoint",
    )
    parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
    cli_args = parser.parse_args()
    convert_pixtral_checkpoint(
        checkpoint_dir=cli_args.checkpoint_dir,
        checkpoint_name=cli_args.checkpoint_name,
        vit_type=cli_args.vit_type,
    )
download_diffusion.py ADDED
@@ -0,0 +1,123 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ from pathlib import Path
18
+
19
+ from huggingface_hub import snapshot_download
20
+
21
+ from Cosmos.convert_pixtral_ckpt import convert_pixtral_checkpoint
22
+
23
+
24
def parse_args():
    """Parse command line options for the Cosmos-1.0 Diffusion downloader.

    Returns:
        argparse.Namespace: Parsed arguments with fields ``model_sizes``,
        ``model_types``, ``cosmos_version``, and ``checkpoint_dir``.
    """
    parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Diffusion models from Hugging Face")
    parser.add_argument(
        "--model_sizes",
        nargs="*",
        default=["7B", "14B"],  # Download all by default
        choices=["7B", "14B"],
        help="Which model sizes to download. Possible values: 7B, 14B",
    )
    parser.add_argument(
        "--model_types",
        nargs="*",
        default=["Text2World", "Video2World"],  # Download all by default
        choices=["Text2World", "Video2World"],
        help="Which model types to download. Possible values: Text2World, Video2World",
    )
    parser.add_argument(
        "--cosmos_version",
        type=str,
        default="1.0",
        choices=["1.0"],
        help="Which version of Cosmos to download. Only 1.0 is available at the moment.",
    )
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints."
    )
    return parser.parse_args()
58
+
59
+
60
def main(args):
    """Download the requested Cosmos-1.0 Diffusion checkpoints from Hugging Face.

    Downloads the selected Diffusion models (per ``--model_sizes`` /
    ``--model_types``), plus the guardrail and tokenizer models that are always
    required, and converts the Pixtral checkpoint when Video2World models are
    requested.

    Args:
        args: Parsed command line arguments from ``parse_args()``.
    """
    ORG_NAME = "nvidia"

    # Map the size flag to its Hugging Face repository base name.
    model_map = {
        "7B": "Cosmos-1.0-Diffusion-7B",
        "14B": "Cosmos-1.0-Diffusion-14B",
    }

    # Models downloaded regardless of the requested sizes/types.
    extra_models = [
        "Cosmos-1.0-Guardrail",
        "Cosmos-1.0-Tokenizer-CV8x8x8",
    ]
    if "Text2World" in args.model_types:
        extra_models.append("Cosmos-1.0-Prompt-Upsampler-12B-Text2World")

    # Create local checkpoints folder
    checkpoints_dir = Path(args.checkpoint_dir)
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    # Diffusion repos only need this subset of files.
    download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"])

    # Download the requested Diffusion models.
    for size in args.model_sizes:
        for model_type in args.model_types:
            model_name = f"{model_map[size]}-{model_type}"
            repo_id = f"{ORG_NAME}/{model_name}"
            local_dir = checkpoints_dir.joinpath(model_name)
            local_dir.mkdir(parents=True, exist_ok=True)

            print(f"Downloading {repo_id} to {local_dir}...")
            snapshot_download(
                repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs
            )

    # Download the always-included models (all files — no allow_patterns filter,
    # since e.g. Guardrail needs its full repository contents).
    for model_name in extra_models:
        repo_id = f"{ORG_NAME}/{model_name}"
        local_dir = checkpoints_dir.joinpath(model_name)
        local_dir.mkdir(parents=True, exist_ok=True)

        print(f"Downloading {repo_id} to {local_dir}...")
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
            local_dir_use_symlinks=False,
        )

    if "Video2World" in args.model_types:
        # Prompt Upsampler for Cosmos-1.0-Diffusion-Video2World models:
        # download + convert the Pixtral checkpoint under checkpoint_dir.
        convert_pixtral_checkpoint(
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name="Pixtral-12B",
            vit_type="pixtral-12b-vit",
        )
119
+
120
+
121
if __name__ == "__main__":
    # Script entry point: parse CLI options, then download the checkpoints.
    main(parse_args())
guardrail_presets.py ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import os
17
+
18
+ import numpy as np
19
+
20
+ from cosmos1.models.guardrail.aegis.aegis import Aegis
21
+ from cosmos1.models.guardrail.blocklist.blocklist import Blocklist
22
+ from cosmos1.models.guardrail.common.core import GuardrailRunner
23
+ from cosmos1.models.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter
24
+ from cosmos1.models.guardrail.video_content_safety_filter.video_content_safety_filter import VideoContentSafetyFilter
25
+ from Cosmos.utils import log
26
+
27
+
28
def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the guardrail runner used to screen text prompts.

    Combines the blocklist and Aegis safety models, whose weights live under
    ``<checkpoint_dir>/blocklist`` and ``<checkpoint_dir>/aegis`` respectively.
    """
    blocklist = Blocklist(os.path.join(checkpoint_dir, "blocklist"))
    aegis = Aegis(os.path.join(checkpoint_dir, "aegis"))
    return GuardrailRunner(safety_models=[blocklist, aegis])
33
+
34
+
35
def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the guardrail runner used to screen generated video frames.

    Pairs a video content-safety filter (safety model) with a RetinaFace-based
    face blur filter (postprocessor).
    """
    content_filter = VideoContentSafetyFilter(os.path.join(checkpoint_dir, "video_content_safety_filter"))
    face_blur = RetinaFaceFilter(os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth"))
    return GuardrailRunner(
        safety_models=[content_filter],
        postprocessors=[face_blur],
    )
43
+
44
+
45
def run_text_guardrail(prompt: str, guardrail_runner: "GuardrailRunner") -> bool:
    """Run the text guardrail on the prompt, checking for content safety.

    Args:
        prompt: The text prompt.
        guardrail_runner: The text guardrail runner.

    Returns:
        bool: Whether the prompt is safe.
    """
    safe, reason = guardrail_runner.run_safety_check(prompt)
    if not safe:
        log.critical(f"GUARDRAIL BLOCKED: {reason}")
    return safe
59
+
60
+
61
def run_video_guardrail(frames: "np.ndarray", guardrail_runner: "GuardrailRunner") -> "np.ndarray | None":
    """Run the video guardrail on the frames, checking for content safety and applying face blur.

    Args:
        frames: The frames of the generated video.
        guardrail_runner: The video guardrail runner.

    Returns:
        The processed frames if safe, otherwise None.
    """
    safe, reason = guardrail_runner.run_safety_check(frames)
    if not safe:
        log.critical(f"GUARDRAIL BLOCKED: {reason}")
        return None
    # Safe: apply postprocessors (e.g. face blur) and hand back the frames.
    return guardrail_runner.postprocess(frames)
inference_utils.py ADDED
@@ -0,0 +1,726 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import importlib
18
+ from contextlib import contextmanager
19
+ from typing import List, NamedTuple, Optional, Tuple
20
+
21
+ from Cosmos.utils import misc
22
+ import einops
23
+ import imageio
24
+ import numpy as np
25
+ import torch
26
+ import torchvision.transforms.functional as transforms_F
27
+
28
+ from Cosmos.model_t2w import DiffusionT2WModel
29
+ from Cosmos.model_v2w import DiffusionV2WModel
30
+ from Cosmos.utils import log
31
+ from Cosmos.utils.config_helper import get_config_module, override
32
+ from Cosmos.utils.io import load_from_fileobj
33
+
34
+ TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
35
+ if TORCH_VERSION >= (1, 11):
36
+ from torch.ao import quantization
37
+ from torch.ao.quantization import FakeQuantizeBase, ObserverBase
38
+ elif (
39
+ TORCH_VERSION >= (1, 8)
40
+ and hasattr(torch.quantization, "FakeQuantizeBase")
41
+ and hasattr(torch.quantization, "ObserverBase")
42
+ ):
43
+ from torch import quantization
44
+ from torch.quantization import FakeQuantizeBase, ObserverBase
45
+
46
+ DEFAULT_AUGMENT_SIGMA = 0.001
47
+
48
+
49
def add_common_arguments(parser):
    """Add common command line arguments for text2world and video2world generation.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to

    The arguments include:
    - checkpoint_dir: Base directory containing model weights
    - tokenizer_dir: Directory containing tokenizer weights
    - video_save_name: Output video filename for single video generation
    - video_save_folder: Output directory for batch video generation
    - prompt: Text prompt for single video generation
    - batch_input_path: Path to JSONL file with input prompts for batch video generation
    - negative_prompt: Text prompt describing undesired attributes
    - num_steps: Number of diffusion sampling steps
    - guidance: Classifier-free guidance scale
    - num_video_frames: Number of frames to generate
    - height/width: Output video dimensions
    - fps: Output video frame rate
    - seed: Random seed for reproducibility
    - Various model offloading flags
    """
    default_negative_prompt = (
        "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
        "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
        "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
        "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special "
        "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and "
        "flickering. Overall, the video is of poor quality."
    )
    # (flag, keyword arguments) for every shared option, in display order.
    option_specs = [
        ("--checkpoint_dir", dict(type=str, default="checkpoints", help="Base directory containing model checkpoints")),
        (
            "--tokenizer_dir",
            dict(
                type=str,
                default="Cosmos-1.0-Tokenizer-CV8x8x8",
                help="Tokenizer weights directory relative to checkpoint_dir",
            ),
        ),
        ("--video_save_name", dict(type=str, default="output", help="Output filename for generating a single video")),
        (
            "--video_save_folder",
            dict(type=str, default="outputs/", help="Output folder for generating a batch of videos"),
        ),
        ("--prompt", dict(type=str, help="Text prompt for generating a single video")),
        (
            "--batch_input_path",
            dict(type=str, help="Path to a JSONL file of input prompts for generating a batch of videos"),
        ),
        ("--negative_prompt", dict(type=str, default=default_negative_prompt, help="Negative prompt for the video")),
        ("--num_steps", dict(type=int, default=35, help="Number of diffusion sampling steps")),
        ("--guidance", dict(type=float, default=7, help="Guidance scale value")),
        ("--num_video_frames", dict(type=int, default=121, help="Number of video frames to sample")),
        ("--height", dict(type=int, default=704, help="Height of video to sample")),
        ("--width", dict(type=int, default=1280, help="Width of video to sample")),
        ("--fps", dict(type=int, default=24, help="FPS of the sampled video")),
        ("--seed", dict(type=int, default=1, help="Random seed")),
        ("--disable_prompt_upsampler", dict(action="store_true", help="Disable prompt upsampling")),
        ("--offload_diffusion_transformer", dict(action="store_true", help="Offload DiT after inference")),
        ("--offload_tokenizer", dict(action="store_true", help="Offload tokenizer after inference")),
        ("--offload_text_encoder_model", dict(action="store_true", help="Offload text encoder model after inference")),
        ("--offload_prompt_upsampler", dict(action="store_true", help="Offload prompt upsampler after inference")),
        ("--offload_guardrail_models", dict(action="store_true", help="Offload guardrail models after inference")),
    ]
    for flag, kwargs in option_specs:
        parser.add_argument(flag, **kwargs)
150
+
151
+
152
def validate_args(args: argparse.Namespace, inference_type: str) -> None:
    """Validate command line arguments for text2world and video2world generation."""
    assert inference_type in (
        "text2world",
        "video2world",
    ), "Invalid inference_type, must be 'text2world' or 'video2world'"

    # A prompt source is mandatory unless video2world will upsample the prompt.
    prompt_required = inference_type == "text2world" or (
        inference_type == "video2world" and args.disable_prompt_upsampler
    )
    if prompt_required:
        assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided."

    # Single-video video2world generation additionally needs a conditioning input.
    if inference_type == "video2world" and not args.batch_input_path:
        assert (
            args.input_image_or_video_path
        ), "--input_image_or_video_path must be provided for single video generation."
166
+
167
+
168
+ class _IncompatibleKeys(
169
+ NamedTuple(
170
+ "IncompatibleKeys",
171
+ [
172
+ ("missing_keys", List[str]),
173
+ ("unexpected_keys", List[str]),
174
+ ("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
175
+ ],
176
+ )
177
+ ):
178
+ pass
179
+
180
+
181
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    """Load a model checkpoint with non-strict matching, handling shape mismatches.

    Args:
        model (torch.nn.Module): Model to load weights into
        checkpoint_state_dict (dict): State dict from checkpoint

    Returns:
        _IncompatibleKeys: Named tuple containing:
            - missing_keys: Keys present in model but missing from checkpoint
            - unexpected_keys: Keys present in checkpoint but not in model
            - incorrect_shapes: Keys with mismatched tensor shapes

    The function handles special cases like:
    - Uninitialized parameters
    - Quantization observers
    - TransformerEngine FP8 states

    Note:
        `checkpoint_state_dict` is mutated in place: entries with mismatched
        shapes are popped before delegating to `model.load_state_dict`.
    """
    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    # First pass: drop checkpoint entries whose shapes cannot be loaded into
    # the model. Keys absent from the model are left for strict=False below.
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:  # Key introduced by TransformerEngine for FP8
                log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not."
                )

            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                # Observer/FakeQuantize classes only exist on torch >= 1.8; on
                # older torch the quantization special-casing is skipped entirely.
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Do not remove modules with expected shape mismatches
                        # them from the state_dict loading. They have special logic
                        # in _load_from_state_dict to handle the mismatches.
                        continue

                # Record the mismatch and drop the entry so load_state_dict
                # does not raise on it.
                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    # Second pass: PyTorch itself collects missing/unexpected keys (strict=False).
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
257
+
258
+
259
@contextmanager
def skip_init_linear():
    """Context manager that temporarily disables nn.Linear weight initialization.

    Patches ``torch.nn.Linear.reset_parameters`` and
    ``torch.nn.init.xavier_uniform_`` with no-ops so large models can be
    constructed quickly before their weights are loaded from a checkpoint.

    Fix over the previous version: restoration now happens in a ``finally``
    block, so an exception raised while the manager is active no longer leaves
    torch globally monkey-patched (which would silently break every later
    model construction in the process).
    """
    orig_reset_parameters = torch.nn.Linear.reset_parameters
    orig_xavier_uniform_ = torch.nn.init.xavier_uniform_
    torch.nn.Linear.reset_parameters = lambda x: x
    torch.nn.init.xavier_uniform_ = lambda x: x
    try:
        yield
    finally:
        # Always restore the original initializers, even if the body raised.
        torch.nn.Linear.reset_parameters = orig_reset_parameters
        torch.nn.init.xavier_uniform_ = orig_xavier_uniform_
269
+
270
+
271
def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=DiffusionT2WModel,
):
    """Instantiate a (weightless) diffusion model from a named experiment config.

    Args:
        config_job_name: Experiment name used to select the config override.
        config_file: Path to the config module to load.
        model_class: Model class to construct from ``config.model``.

    Returns:
        A ``model_class`` instance built from the validated, frozen config.
    """
    config_module = get_config_module(config_file)
    config = importlib.import_module(config_module).make_config()
    config = override(config, ["--", f"experiment={config_job_name}"])

    # Validate, then freeze so the config cannot be mutated afterwards.
    config.validate()
    config.freeze()  # type: ignore

    # Linear weights are replaced by a checkpoint later, so skip their
    # (slow) default initialization during construction.
    with skip_init_linear():
        return model_class(config.model)
290
+
291
+
292
def load_network_model(model: DiffusionT2WModel, ckpt_path: str):
    """Load diffusion network weights from `ckpt_path` into `model` and move it to GPU.

    Set-up runs under skip_init_linear() since the freshly-built linear layers
    are immediately overwritten by the checkpoint.
    """
    with skip_init_linear():
        model.set_up_model()
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Non-strict load: mismatches are logged rather than fatal.
    log.debug(non_strict_load_model(model.model, state_dict))
    model.cuda()
298
+
299
+
300
def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str):
    """Instantiate the video tokenizer from `tokenizer_dir` and move `model` to GPU.

    Linear initialization is skipped because the tokenizer weights come from
    the checkpoint directory.
    """
    with skip_init_linear():
        model.set_up_tokenizer(tokenizer_dir)
    model.cuda()
304
+
305
+
306
def prepare_data_batch(
    height: int,
    width: int,
    num_frames: int,
    fps: int,
    prompt_embedding: torch.Tensor,
    negative_prompt_embedding: Optional[torch.Tensor] = None,
):
    """Prepare input batch tensors for video generation.

    Args:
        height (int): Height of video frames
        width (int): Width of video frames
        num_frames (int): Number of frames to generate
        fps (int): Frames per second
        prompt_embedding (torch.Tensor): Encoded text prompt embeddings
        negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings

    Returns:
        dict: Batch dictionary containing:
            - video: Zero tensor of target video shape
            - t5_text_mask: Attention mask for text embeddings
            - image_size: Target frame dimensions
            - fps: Target frame rate
            - num_frames: Number of frames
            - padding_mask: Frame padding mask
            - t5_text_embeddings: Prompt embeddings
            - neg_t5_text_embeddings: Negative prompt embeddings (if provided)
            - neg_t5_text_mask: Mask for negative embeddings (if provided)
    """
    bf16 = torch.bfloat16
    # Base batch: the "video" entry is a placeholder of the target shape; the
    # actual content is produced by sampling.
    batch = {
        "video": torch.zeros((1, 3, num_frames, height, width), dtype=torch.uint8).cuda(),
        "t5_text_mask": torch.ones(1, 512, dtype=bf16).cuda(),
        "image_size": torch.tensor([[height, width, height, width]], dtype=bf16).cuda(),
        "fps": torch.tensor([fps], dtype=bf16).cuda(),
        "num_frames": torch.tensor([num_frames], dtype=bf16).cuda(),
        "padding_mask": torch.zeros((1, 1, height, width), dtype=bf16).cuda(),
    }

    # Positive (and optionally negative) text embeddings, cast to bfloat16.
    batch["t5_text_embeddings"] = prompt_embedding.to(dtype=bf16).cuda()
    if negative_prompt_embedding is not None:
        batch["neg_t5_text_embeddings"] = negative_prompt_embedding.to(dtype=bf16).cuda()
        batch["neg_t5_text_mask"] = torch.ones(1, 512, dtype=bf16).cuda()

    return batch
357
+
358
+
359
def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames):
    """Prepare complete input batch for video generation including latent dimensions.

    Args:
        model: Diffusion model instance
        prompt_embedding (torch.Tensor): Text prompt embeddings
        negative_prompt_embedding (torch.Tensor): Negative prompt embeddings
        height (int): Output video height
        width (int): Output video width
        fps (int): Output video frame rate
        num_video_frames (int): Number of frames to generate

    Returns:
        tuple:
            - data_batch (dict): Complete model input batch
            - state_shape (list): Shape of latent state [C,T,H,W] accounting for VAE compression
    """
    data_batch = prepare_data_batch(
        height=height,
        width=width,
        num_frames=num_video_frames,
        fps=fps,
        prompt_embedding=prompt_embedding,
        negative_prompt_embedding=negative_prompt_embedding,
    )
    # Latent shape reflects the tokenizer's temporal and spatial compression.
    tokenizer = model.tokenizer
    state_shape = [
        tokenizer.channel,
        tokenizer.get_latent_num_frames(num_video_frames),
        height // tokenizer.spatial_compression_factor,
        width // tokenizer.spatial_compression_factor,
    ]
    return data_batch, state_shape
391
+
392
+
393
def generate_world_from_text(
    model: DiffusionT2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
):
    """Generate video from text prompt using diffusion model.

    Args:
        model (DiffusionT2WModel): Text-to-video diffusion model
        state_shape (list[int]): Latent state dimensions [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Model input batch with embeddings
        guidance (float): Classifier-free guidance scale
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for reproducibility

    Returns:
        np.ndarray: Generated video frames [T,H,W,C], range [0,255]
    """
    # Start from pure noise at the SDE's maximum noise level.
    init_noise = misc.arch_invariant_rand(
        (1,) + tuple(state_shape),
        torch.float32,
        model.tensor_kwargs["device"],
        seed,
    )
    x_sigma_max = init_noise * model.sde.sigma_max

    # Guided diffusion sampling; decoding to pixel space happens inside the model.
    return model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        x_sigma_max=x_sigma_max,
    )
443
+
444
+
445
def generate_world_from_video(
    model: DiffusionV2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
    condition_latent: torch.Tensor,
    num_input_frames: int,
) -> Tuple[np.array, list, list]:
    """Generate video using a conditioning video/image input.

    Args:
        model (DiffusionV2WModel): The diffusion model instance
        state_shape (list[int]): Shape of the latent state [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Batch containing model inputs including text embeddings
        guidance (float): Classifier-free guidance scale for sampling
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for generation
        condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
        num_input_frames (int): Number of input frames

    Returns:
        np.array: Generated video frames in shape [T,H,W,C], range [0,255]
    """
    assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported"
    augment_sigma = DEFAULT_AUGMENT_SIGMA

    # Zero-pad the condition latent along time so it matches the latent state.
    target_num_latent_frames = state_shape[1]
    if condition_latent.shape[2] < target_num_latent_frames:
        b, c, t, h, w = condition_latent.shape
        padding = condition_latent.new_zeros(b, c, target_num_latent_frames - t, h, w)
        condition_latent = torch.cat([condition_latent, padding], dim=2).contiguous()

    # Number of latent frames actually backed by real input frames.
    num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)

    # Start from pure noise at the SDE's maximum noise level.
    init_noise = misc.arch_invariant_rand(
        (1,) + tuple(state_shape),
        torch.float32,
        model.tensor_kwargs["device"],
        seed,
    )
    x_sigma_max = init_noise * model.sde.sigma_max

    return model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        condition_latent=condition_latent,
        num_condition_t=num_of_latent_condition,
        condition_video_augment_sigma_in_inference=augment_sigma,
        x_sigma_max=x_sigma_max,
    )
510
+
511
+
512
def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = "mp4",
    H: int = None,
    W: int = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Read video or image file and convert to tensor format.

    Args:
        input_path (str): Path to input video/image file
        input_path_format (str): Format of input file (default: "mp4")
        H (int, optional): Height to resize frames to
        W (int, optional): Width to resize frames to
        normalize (bool): Whether to normalize pixel values to [-1,1] (default: True)
        max_frames (int): Maximum number of frames to read (-1 for all frames)
        also_return_fps (bool): Whether to return fps along with frames

    Returns:
        torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested

    Note:
        With normalize=True the tensor is also moved to CUDA; with
        normalize=False it stays on CPU in raw [0,255] bfloat16.
    """
    log.debug(f"Reading video from {input_path}")

    loaded_data = load_from_fileobj(input_path, format=input_path_format)
    frames, meta_data = loaded_data
    if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
        frames = np.array(frames[0])  # HWC, [0,255]
        if frames.shape[-1] > 3:  # RGBA, set the transparent to white
            # Separate the RGB and Alpha channels
            rgb_channels = frames[..., :3]
            alpha_channel = frames[..., 3] / 255.0  # Normalize alpha channel to [0, 1]

            # Create a white background
            white_bg = np.ones_like(rgb_channels) * 255  # White background in RGB

            # Blend the RGB channels with the white background based on the alpha channel
            frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
                np.uint8
            )
        frames = [frames]
        # Still images have no meaningful frame rate.
        fps = 0
    else:
        # NOTE(review): raises TypeError if the container metadata has no
        # "fps" entry — confirm all expected video formats provide it.
        fps = int(meta_data.get("fps"))
    if max_frames != -1:
        frames = frames[:max_frames]
    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    if normalize:
        # Maps [0,255] to roughly [-1, 1] (255 -> ~0.992).
        input_tensor = input_tensor / 128.0 - 1.0
        input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
        log.debug(f"Raw data shape: {input_tensor.shape}")
        if H is not None and W is not None:
            input_tensor = transforms_F.resize(
                input_tensor,
                size=(H, W),  # type: ignore
                interpolation=transforms_F.InterpolationMode.BICUBIC,
                antialias=True,
            )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        input_tensor = input_tensor.to("cuda")
    log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor
579
+
580
+
581
def compute_num_latent_frames(model: "DiffusionV2WModel", num_input_frames: int, downsample_factor=8) -> int:
    """This function computes the number of latent frames given the number of input frames.

    Args:
        model (DiffusionV2WModel): video generation model
        num_input_frames (int): number of input frames
        downsample_factor (int): downsample factor for temporal reduce
    Returns:
        int: number of latent frames
    """
    vae = model.tokenizer.video_vae
    # Whole pixel chunks map to whole latent chunks.
    num_latent_frames = (num_input_frames // vae.pixel_chunk_duration) * vae.latent_chunk_duration

    remainder = num_input_frames % vae.latent_chunk_duration
    if remainder == 1:
        num_latent_frames += 1
    elif remainder > 1:
        pixel_remainder = num_input_frames % vae.pixel_chunk_duration
        assert (
            pixel_remainder - 1
        ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}"
        num_latent_frames += 1 + (pixel_remainder - 1) // downsample_factor

    return num_latent_frames
606
+
607
+
608
def create_condition_latent_from_input_frames(
    model: DiffusionV2WModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Build the conditioning latent for video extension from raw input frames.

    The last ``num_frames_condition`` frames of the input are kept, zero-padded
    at the end up to the tokenizer's pixel chunk duration, and encoded.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1]
        num_frames_condition (int): Number of trailing frames to use for conditioning

    Returns:
        tuple: (condition_latent, encode_input_frames) where:
            - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
            - encode_input_frames (torch.Tensor): Padded input frames used for encoding
    """
    batch, channels, _, height, width = input_frames.shape
    # Always encode a full pixel chunk; the condition occupies its first frames.
    num_frames_encode = (
        model.tokenizer.pixel_chunk_duration
    )  # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1
    log.debug(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )
    log.debug(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}"

    # Conditional frames go at the beginning of the clip; the tail is zero padding.
    condition_frames = input_frames[:, :, -num_frames_condition:]
    pad = condition_frames.new_zeros(batch, channels, num_frames_encode - num_frames_condition, height, width)
    encode_input_frames = torch.cat([condition_frames, pad], dim=2)

    log.debug(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    latent = model.encode(encode_input_frames)
    return latent, encode_input_frames
656
+
657
+
658
def get_condition_latent(
    model: DiffusionV2WModel,
    input_image_or_video_path: str,
    num_input_frames: int = 1,
    state_shape: list[int] = None,
):
    """Load conditioning media from disk and encode it into a latent.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_image_or_video_path (str): Path to conditioning image/video
        num_input_frames (int): Number of input frames for video2world prediction
        state_shape (list[int]): Latent state shape; defaults to ``model.state_shape``

    Returns:
        torch.Tensor: Encoded latent condition [B,C,T,H,W], cast to bfloat16.
    """
    assert num_input_frames > 0, "num_input_frames must be greater than 0"
    if state_shape is None:
        state_shape = model.state_shape

    # Pixel-space resolution implied by the latent shape and the tokenizer's
    # spatial compression factor.
    spatial_factor = model.tokenizer.spatial_compression_factor
    target_h = state_shape[-2] * spatial_factor
    target_w = state_shape[-1] * spatial_factor

    extension = input_image_or_video_path.split(".")[-1]
    input_frames = read_video_or_image_into_frames_BCTHW(
        input_image_or_video_path,
        input_path_format=extension,
        H=target_h,
        W=target_w,
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
    return condition_latent.to(torch.bfloat16)
697
+
698
+
699
def check_input_frames(input_path: str, required_frames: int) -> bool:
    """Validate that an input image/video can supply enough frames.

    Args:
        input_path: Path to input video or image
        required_frames: Number of required frames

    Returns:
        True when the input provides at least ``required_frames`` frames,
        False otherwise (including unreadable video files).
    """
    if input_path.endswith((".jpg", ".jpeg", ".png")):
        # A still image contributes exactly one frame.
        if required_frames > 1:
            log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
            return False
        return True  # Let the pipeline handle image loading

    # Video input: probe the container and count its frames.
    try:
        reader = imageio.get_reader(input_path, "ffmpeg")
        frame_count = reader.count_frames()
    except Exception as e:
        log.error(f"Error reading video file {input_path}: {e}")
        return False

    if frame_count < required_frames:
        log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
        return False
    return True
model_t2w.py ADDED
@@ -0,0 +1,282 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import Callable, Dict, Optional, Tuple
17
+
18
+ from Cosmos.utils import misc
19
+ import torch
20
+ from torch import Tensor
21
+
22
+ from Cosmos.conditioner import CosmosCondition
23
+ from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
24
+ from cosmos1.models.diffusion.diffusion.modules.denoiser_scaling import EDMScaling
25
+ from cosmos1.models.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler
26
+ from Cosmos.types import DenoisePrediction
27
+ from Cosmos.module.blocks import FourierFeatures
28
+ from Cosmos.module.pretrained_vae import BaseVAE
29
+ from Cosmos.utils import log
30
+ from Cosmos.lazy_config import instantiate as lazy_instantiate
31
+
32
+
33
+ class EDMSDE:
34
+ def __init__(
35
+ self,
36
+ sigma_max: float,
37
+ sigma_min: float,
38
+ ):
39
+ self.sigma_max = sigma_max
40
+ self.sigma_min = sigma_min
41
+
42
+
43
+ class DiffusionT2WModel(torch.nn.Module):
44
+ """Text-to-world diffusion model that generates video frames from text descriptions.
45
+
46
+ This model implements a diffusion-based approach for generating videos conditioned on text input.
47
+ It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
48
+ and classifier-free guidance.
49
+ """
50
+
51
+ def __init__(self, config):
52
+ """Initialize the diffusion model.
53
+
54
+ Args:
55
+ config: Configuration object containing model parameters and architecture settings
56
+ """
57
+ super().__init__()
58
+ # Initialize trained_data_record with defaultdict, key: image, video, iteration
59
+ self.config = config
60
+
61
+ self.precision = {
62
+ "float32": torch.float32,
63
+ "float16": torch.float16,
64
+ "bfloat16": torch.bfloat16,
65
+ }[config.precision]
66
+ self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
67
+ log.debug(f"DiffusionModel: precision {self.precision}")
68
+ # Timer passed to network to detect slow ranks.
69
+ # 1. set data keys and data information
70
+ self.sigma_data = config.sigma_data
71
+ self.state_shape = list(config.latent_shape)
72
+ self.setup_data_key()
73
+
74
+ # 2. setup up diffusion processing and scaling~(pre-condition), sampler
75
+ self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002)
76
+ self.sampler = Sampler()
77
+ self.scaling = EDMScaling(self.sigma_data)
78
+ self.tokenizer = None
79
+ self.model = None
80
+
81
+ @property
82
+ def net(self):
83
+ return self.model.net
84
+
85
+ @property
86
+ def conditioner(self):
87
+ return self.model.conditioner
88
+
89
+ @property
90
+ def logvar(self):
91
+ return self.model.logvar
92
+
93
+ def set_up_tokenizer(self, tokenizer_dir: str):
94
+ self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
95
+ self.tokenizer.load_weights(tokenizer_dir)
96
+ if hasattr(self.tokenizer, "reset_dtype"):
97
+ self.tokenizer.reset_dtype()
98
+
99
+ @misc.timer("DiffusionModel: set_up_model")
100
+ def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
101
+ """Initialize the core model components including network, conditioner and logvar."""
102
+ self.model = self.build_model()
103
+ self.model = self.model.to(memory_format=memory_format, **self.tensor_kwargs)
104
+
105
+ def build_model(self) -> torch.nn.ModuleDict:
106
+ """Construct the model's neural network components.
107
+
108
+ Returns:
109
+ ModuleDict containing the network, conditioner and logvar components
110
+ """
111
+ config = self.config
112
+ net = lazy_instantiate(config.net)
113
+ conditioner = lazy_instantiate(config.conditioner)
114
+ logvar = torch.nn.Sequential(
115
+ FourierFeatures(num_channels=128, normalize=True), torch.nn.Linear(128, 1, bias=False)
116
+ )
117
+
118
+ return torch.nn.ModuleDict(
119
+ {
120
+ "net": net,
121
+ "conditioner": conditioner,
122
+ "logvar": logvar,
123
+ }
124
+ )
125
+
126
+ @torch.no_grad()
127
+ def encode(self, state: torch.Tensor) -> torch.Tensor:
128
+ """Encode input state into latent representation using VAE.
129
+
130
+ Args:
131
+ state: Input tensor to encode
132
+
133
+ Returns:
134
+ Encoded latent representation scaled by sigma_data
135
+ """
136
+ return self.tokenizer.encode(state) * self.sigma_data
137
+
138
+ @torch.no_grad()
139
+ def decode(self, latent: torch.Tensor) -> torch.Tensor:
140
+ """Decode latent representation back to pixel space using VAE.
141
+
142
+ Args:
143
+ latent: Latent tensor to decode
144
+
145
+ Returns:
146
+ Decoded tensor in pixel space
147
+ """
148
+ return self.tokenizer.decode(latent / self.sigma_data)
149
+
150
+ def setup_data_key(self) -> None:
151
+ """Configure input data keys for video and image data."""
152
+ self.input_data_key = self.config.input_data_key # by default it is video key for Video diffusion model
153
+
154
+ def get_x0_fn_from_batch(
155
+ self,
156
+ data_batch: Dict,
157
+ guidance: float = 1.5,
158
+ is_negative_prompt: bool = False,
159
+ ) -> Callable:
160
+ """
161
+ Generates a callable function `x0_fn` based on the provided data batch and guidance factor.
162
+
163
+ This function processes the input data batch through a conditioning workflow to obtain
164
+ conditioned and unconditioned states. It then defines a nested function `x0_fn` which
165
+ applies denoising on an input `noise_x` at a given noise level `sigma`.
166
+
167
+ Args:
168
+ data_batch: A batch of data used for conditioning. Format should align with conditioner
169
+ guidance: Scalar value that modulates influence of conditioned vs unconditioned state
170
+ is_negative_prompt: Use negative prompt t5 in uncondition if true
171
+
172
+ Returns:
173
+ A function `x0_fn(noise_x, sigma)` that takes noise_x and sigma, returns x0 prediction
174
+ """
175
+ if is_negative_prompt:
176
+ condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
177
+ else:
178
+ condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)
179
+
180
+ def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
181
+ cond_x0 = self.denoise(noise_x, sigma, condition).x0
182
+ uncond_x0 = self.denoise(noise_x, sigma, uncondition).x0
183
+ raw_x0 = cond_x0 + guidance * (cond_x0 - uncond_x0)
184
+ if "guided_image" in data_batch:
185
+ # replacement trick that enables inpainting with base model
186
+ assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present"
187
+ guide_image = data_batch["guided_image"]
188
+ guide_mask = data_batch["guided_mask"]
189
+ raw_x0 = guide_mask * guide_image + (1 - guide_mask) * raw_x0
190
+
191
+ return raw_x0
192
+
193
+ return x0_fn
194
+
195
+ def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction:
196
+ """
197
+ Performs denoising on the input noise data, noise level, and condition
198
+
199
+ Args:
200
+ xt (torch.Tensor): The input noise data.
201
+ sigma (torch.Tensor): The noise level.
202
+ condition (CosmosCondition): conditional information, generated from self.conditioner
203
+
204
+ Returns:
205
+ DenoisePrediction: The denoised prediction, it includes clean data predicton (x0), \
206
+ noise prediction (eps_pred) and optional confidence (logvar).
207
+ """
208
+
209
+ xt = xt.to(**self.tensor_kwargs)
210
+ sigma = sigma.to(**self.tensor_kwargs)
211
+ # get precondition for the network
212
+ c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma)
213
+
214
+ # forward pass through the network
215
+ net_output = self.net(
216
+ x=batch_mul(c_in, xt), # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
217
+ timesteps=c_noise, # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
218
+ **condition.to_dict(),
219
+ )
220
+
221
+ logvar = self.model.logvar(c_noise)
222
+ x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output)
223
+
224
+ # get noise prediction based on sde
225
+ eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma)
226
+
227
+ return DenoisePrediction(x0_pred, eps_pred, logvar)
228
+
229
+ def generate_samples_from_batch(
230
+ self,
231
+ data_batch: Dict,
232
+ guidance: float = 1.5,
233
+ seed: int = 1,
234
+ state_shape: Tuple | None = None,
235
+ n_sample: int | None = None,
236
+ is_negative_prompt: bool = False,
237
+ num_steps: int = 35,
238
+ solver_option: COMMON_SOLVER_OPTIONS = "2ab",
239
+ x_sigma_max: Optional[torch.Tensor] = None,
240
+ sigma_max: float | None = None,
241
+ ) -> Tensor:
242
+ """Generate samples from a data batch using diffusion sampling.
243
+
244
+ This function generates samples from either image or video data batches using diffusion sampling.
245
+ It handles both conditional and unconditional generation with classifier-free guidance.
246
+
247
+ Args:
248
+ data_batch (Dict): Raw data batch from the training data loader
249
+ guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5.
250
+ seed (int, optional): Random seed for reproducibility. Defaults to 1.
251
+ state_shape (Tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None. Defaults to None.
252
+ n_sample (int | None, optional): Number of samples to generate. Defaults to None.
253
+ is_negative_prompt (bool, optional): Whether to use negative prompt for unconditional generation. Defaults to False.
254
+ num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35.
255
+ solver_option (COMMON_SOLVER_OPTIONS, optional): Differential equation solver option. Defaults to "2ab" (multistep solver).
256
+ x_sigma_max (Optional[torch.Tensor], optional): Initial noisy tensor. If None, randomly initialized. Defaults to None.
257
+ sigma_max (float | None, optional): Maximum noise level. Uses self.sde.sigma_max if None. Defaults to None.
258
+
259
+ Returns:
260
+ Tensor: Generated samples after diffusion sampling
261
+ """
262
+ x0_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt)
263
+ if sigma_max is None:
264
+ sigma_max = self.sde.sigma_max
265
+ else:
266
+ log.info("Using provided sigma_max for diffusion sampling.")
267
+ if x_sigma_max is None:
268
+ x_sigma_max = (
269
+ misc.arch_invariant_rand(
270
+ (n_sample,) + tuple(state_shape),
271
+ torch.float32,
272
+ self.tensor_kwargs["device"],
273
+ seed,
274
+ )
275
+ * sigma_max
276
+ )
277
+
278
+ samples = self.sampler(
279
+ x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option
280
+ )
281
+
282
+ return samples
model_v2w.py ADDED
@@ -0,0 +1,341 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from dataclasses import dataclass
17
+ from typing import Callable, Dict, Optional, Tuple, Union
18
+
19
+ from Cosmos.utils import misc
20
+ import torch
21
+ from torch import Tensor
22
+
23
+ from Cosmos.conditioner import VideoExtendCondition
24
+ from cosmos1.models.diffusion.config.base.conditioner import VideoCondBoolConfig
25
+ from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
26
+ from Cosmos.model_t2w import DiffusionT2WModel
27
+ from Cosmos.utils import log
28
+
29
+
30
@dataclass
class VideoDenoisePrediction:
    """Denoiser outputs for video-extension (video2world) sampling."""

    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of noise prediction, can be used as a confidence / uncertainty
    xt: Optional[torch.Tensor] = None  # input to the network, before multiplying with c_in
    x0_pred_replaced: Optional[torch.Tensor] = None  # x0 prediction with condition region replaced by gt_latent
37
+
38
+
39
class DiffusionV2WModel(DiffusionT2WModel):
    """Video-to-world extension of the text-to-world diffusion model.

    In addition to text conditioning, generation is conditioned on ground-truth
    latent frames (e.g. encoded from an input image or video clip).
    """

    def __init__(self, config):
        """Initialize with the same configuration object as DiffusionT2WModel."""
        super().__init__(config)

    def augment_conditional_latent_frames(
        self,
        condition: VideoExtendCondition,
        cfg_video_cond_bool: VideoCondBoolConfig,
        gt_latent: Tensor,
        condition_video_augment_sigma_in_inference: float = 0.001,
        sigma: Tensor = None,
        seed: int = 1,
    ) -> Union[VideoExtendCondition, Tensor]:
        """Augments the conditional frames with noise during inference.

        Args:
            condition (VideoExtendCondition): condition object
                condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor.
                condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor. will be concat with the input for the network.
            cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
            gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            sigma (Tensor): noise level for the generation region
            seed (int): random seed for reproducibility
        Returns:
            VideoExtendCondition: updated condition object
                condition_video_augment_sigma: sigma for the condition region, feed to the network
            augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W

        """

        # Inference only, use fixed sigma for the condition region
        assert (
            condition_video_augment_sigma_in_inference is not None
        ), "condition_video_augment_sigma_in_inference should be provided"
        augment_sigma = condition_video_augment_sigma_in_inference

        if augment_sigma >= sigma.flatten()[0]:
            # This is a inference trick! If the sampling sigma is smaller than the augment sigma, we will start denoising the condition region together.
            # This is achieved by setting all region as `generation`, i.e. value=0
            log.debug("augment_sigma larger than sigma or other frame, remove condition")
            condition.condition_video_indicator = condition.condition_video_indicator * 0

        augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs)

        # Now apply the augment_sigma to the gt_latent

        # Seeded, architecture-invariant Gaussian noise for the condition region.
        noise = misc.arch_invariant_rand(
            gt_latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )

        augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None]

        _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma)

        # Multiply the whole latent with c_in_augment
        augment_latent_cin = batch_mul(augment_latent, c_in_augment)

        # Since the whole latent will multiply with c_in later, we divide the value to cancel the effect
        _, _, c_in, _ = self.scaling(sigma=sigma)
        augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)

        return condition, augment_latent_cin

    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility
        Returns:
            VideoDenoisePrediction containing:
            - x0: Denoised prediction
            - eps: Noise prediction
            - logvar: Log variance of noise prediction
            - xt: Input before c_in multiplication
            - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """

        assert (
            condition.gt_latent is not None
        ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}"
        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Augment the latent with different sigma value, and add the augment_sigma to the condition object if needed
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator  # [B, 1, T, 1, 1]

        # Compose the model input with condition region (augment_latent) and generation region (noise_x)
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x
        # Call the base model
        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        # Replace the condition region of the prediction with the ground-truth latent.
        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
        add_input_frames_guidance: bool = False,
        x_sigma_max: Optional[torch.Tensor] = None,
    ) -> Tensor:
        """Generates video samples conditioned on input frames.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            seed: Random seed for reproducibility
            state_shape: Shape of output tensor (defaults to model's state shape)
            n_sample: Number of samples to generate (defaults to batch size)
            is_negative_prompt: Whether to use negative prompting
            num_steps: Number of denoising steps
            condition_latent: Conditioning frames tensor (B,C,T,H,W); required
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            x_sigma_max: Maximum noise level tensor

        Returns:
            Generated video samples tensor
        """

        if n_sample is None:
            # Default the sample count to the batch size of the input data.
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        assert condition_latent is not None, "condition_latent should be provided"

        x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            add_input_frames_guidance=add_input_frames_guidance,
            seed=seed,
        )
        if x_sigma_max is None:
            # Seeded, architecture-invariant Gaussian init scaled to the starting noise level.
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * self.sde.sigma_max
            )

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max)
        return samples

    def get_x0_fn_from_batch_with_condition_latent(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        condition_latent: torch.Tensor = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: float = None,
        add_input_frames_guidance: bool = False,
        seed: int = 1,
    ) -> Callable:
        """Creates denoising function for conditional video generation.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            is_negative_prompt: Whether to use negative prompting
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            seed: Random seed for reproducibility

        Returns:
            Function that takes noisy input and noise level and returns denoised prediction
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )

        # When add_input_frames_guidance is set, the unconditional branch drops the
        # video condition too, so guidance also applies to the input frames.
        uncondition.video_cond_bool = False if add_input_frames_guidance else True
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )

        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            cond_x0 = self.denoise(
                noise_x,
                sigma,
                condition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced
            uncond_x0 = self.denoise(
                noise_x,
                sigma,
                uncondition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced

            # Classifier-free guidance on the replaced predictions.
            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        return x0_fn

    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
    ) -> VideoExtendCondition:
        """Adds conditioning masks to VideoExtendCondition object.

        Creates binary indicators and input masks for conditional video generation.

        Args:
            latent_state: Input latent tensor (B,C,T,H,W)
            condition: VideoExtendCondition object to update
            num_condition_t: Number of frames to condition on

        Returns:
            Updated VideoExtendCondition with added masks:
            - condition_video_indicator: Binary tensor marking condition regions
            - condition_video_input_mask: Input mask for network
            - gt_latent: Ground truth latent tensor
        """
        T = latent_state.shape[2]
        latent_dtype = latent_state.dtype
        # Batch dim is 1 and broadcasts against the B-sized latent in later multiplications.
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
            latent_dtype
        )  # 1 for condition region

        # Only in inference to decide the condition region
        assert num_condition_t is not None, "num_condition_t should be provided"
        assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )
        # Mark the first num_condition_t frames as the condition region.
        condition_video_indicator[:, :, :num_condition_t] += 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        B, C, T, H, W = latent_state.shape
        # Create additional input_mask channel, this will be concatenated to the input of the network
        # See design doc section (Implementation detail A.1 and A.2) for visualization
        ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"

        # The input mask indicate whether the input is conditional region or not
        if condition.video_cond_bool:  # Condition one given video frames
            condition.condition_video_input_mask = (
                condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
            )
        else:  # Unconditional case, use for cfg
            condition.condition_video_input_mask = zeros_padding

        return condition
t5_text_encoder.py ADDED
@@ -0,0 +1,108 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from typing import List, Tuple, Union
17
+
18
+ import torch
19
+ import transformers
20
+ from transformers import T5EncoderModel, T5TokenizerFast
21
+
22
+ from Cosmos.utils import log
23
+
24
+ transformers.logging.set_verbosity_error()
25
+
26
+
27
class CosmosT5TextEncoder(torch.nn.Module):
    """Wraps a T5 tokenizer and encoder to turn text prompts into embeddings."""

    def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"):
        """Initializes the T5 tokenizer and encoder.

        Args:
            model_name: The name of the T5 model to use.
            device: The device to use for computations.
            cache_dir: Directory to look up / store downloaded weights. If loading
                with this cache_dir fails, we retry with the transformers default
                cache location.
        """
        super().__init__()
        try:
            self.tokenizer = T5TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
            self.text_encoder = T5EncoderModel.from_pretrained(model_name, cache_dir=cache_dir).to(device)
        except Exception as e:
            log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}")
            self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
            self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device)
        self.text_encoder.eval()  # inference only; disable dropout etc.
        self.device = device

    @torch.inference_mode()
    def encode_prompts(
        self, prompts: Union[str, List[str]], max_length: int = 512
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes text prompts into hidden state representations using a T5 encoder.

        This function tokenizes the input prompts, processes them through a T5 text
        encoder, and returns the last hidden states. The encoded outputs beyond the
        actual sequence length are zero-padded. All prompts in a batch are padded to
        max_length.

        Args:
            prompts: Input text to encode. Can be a single string or a list of strings.
            max_length: Maximum sequence length for tokenization and padding. Longer
                sequences will be truncated. Defaults to 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - Encoded text embeddings of shape (batch_size, max_length, hidden_size)
                  with padding positions zeroed
                - Attention mask of shape (batch_size, max_length)

        Raises:
            ValueError: If the input prompts list is empty.

        Example:
            >>> encoder = CosmosT5TextEncoder()
            >>> embeddings, mask = encoder.encode_prompts(["Hello world", "Another example"], max_length=128)
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if not prompts:
            raise ValueError("The input prompt list is empty.")

        batch_encoding = self.tokenizer.batch_encode_plus(
            prompts,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
            return_length=True,
            return_offsets_mapping=False,
        )

        input_ids = batch_encoding.input_ids.to(self.device)
        attn_mask = batch_encoding.attention_mask.to(self.device)

        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask)

        # Zero the embeddings at padded positions in one vectorized op instead of
        # a per-sample Python loop. With right-padding (the fast T5 tokenizer's
        # behavior here), mask == 0 exactly at positions >= the sequence length,
        # matching the original `encoded_text[b][length:] = 0` semantics.
        encoded_text = outputs.last_hidden_state
        encoded_text = encoded_text.masked_fill(attn_mask.unsqueeze(-1) == 0, 0.0)

        return encoded_text, attn_mask
text2world.py ADDED
@@ -0,0 +1,161 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ from Cosmos.utils import misc
20
+ import torch
21
+
22
+ from Cosmos.inference_utils import add_common_arguments, validate_args
23
+ from Cosmos.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
24
+ from Cosmos.utils import log
25
+ from Cosmos.utils.io import read_prompts_from_file, save_video
26
+
27
+ torch.enable_grad(False)
28
+
29
+
30
def parse_arguments() -> argparse.Namespace:
    """Build the CLI parser for the text2world demo and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Text to world generation demo script")
    # Options shared by every inference demo script.
    add_common_arguments(parser)

    # Options specific to text2world generation.
    parser.add_argument(
        "--diffusion_transformer_dir",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Text2World",
        choices=[
            "Cosmos-1.0-Diffusion-7B-Text2World",
            "Cosmos-1.0-Diffusion-14B-Text2World",
        ],
        help="DiT model weights directory name relative to checkpoint_dir",
    )
    parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--word_limit_to_skip_upsampler",
        type=int,
        default=250,
        help="Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value",
    )

    return parser.parse_args()
61
+
62
+
63
def demo(cfg):
    """Run text-to-world generation demo.

    This function handles the main text-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts from input
    - Generating videos from text prompts
    - Saving the generated videos and corresponding prompts to disk

    Args:
        cfg (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts, save paths)
            - Performance options (model offloading settings)

    The function will save:
        - Generated MP4 video files
        - Text files containing the processed prompts

    If guardrails block the generation, a critical log message is displayed
    and the function continues to the next prompt if available.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "text2world"
    validate_args(cfg, inference_type)

    # Initialize text2world generation model pipeline
    pipeline = DiffusionText2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
    )

    # Handle multiple prompts if prompt file is provided
    if cfg.batch_input_path:
        # Bug fix: was `args.batch_input_path`, which is undefined when demo() is
        # imported and called directly (it relied on a module-level global).
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        if current_prompt is None:
            log.critical("Prompt is missing, skipping world generation.")
            continue

        # Generate video
        generated_output = pipeline.generate(current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler)
        if generated_output is None:
            log.critical("Guardrail blocked text2world generation.")
            continue
        video, prompt = generated_output

        # Batch runs are saved by index; single runs by the configured name.
        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save prompt to text file alongside video (binary mode + explicit
        # encode keeps the bytes identical across platforms).
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
157
+
158
+
159
# Script entry point: parse the CLI arguments once and run the demo.
if __name__ == "__main__":
    args = parse_arguments()
    demo(args)
text2world_prompt_upsampler_inference.py ADDED
@@ -0,0 +1,157 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This demo script is used to run inference for Cosmos-1.0-Prompt-Upsampler-12B-Text2World.
18
+ Command:
19
+ PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py
20
+
21
+ """
22
+ import argparse
23
+ import os
24
+ import re
25
+
26
+ from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config
27
+ from cosmos1.models.autoregressive.model import AutoRegressiveModel
28
+ from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
29
+ from Cosmos import guardrail_presets as guardrail_presets
30
+ from Cosmos.utils import log
31
+
32
+
33
def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel:
    """Instantiate the text2world prompt-upsampler LLM from `checkpoint_dir`.

    The directory is expected to contain `model.pt` and the tokenizer files.
    """
    ckpt_path = os.path.join(checkpoint_dir, "model.pt")
    model_config, tokenizer_config = create_text_model_config(
        model_ckpt_path=ckpt_path,
        tokenizer_path=checkpoint_dir,
        model_family="mistral",
        model_size="12b",
        is_instruct_model=True,
        max_batch_size=1,
        rope_dim="1D",
        add_special_tokens=True,
        max_seq_len=1024,
        pytorch_rope_version="v1",
    )
    log.debug(f"Text prompt upsampler model config: {model_config}")

    # Build the LLM instance and move it onto the GPU.
    llm = AutoRegressiveModel.build(
        model_config=model_config,
        tokenizer_config=tokenizer_config,
    )
    return llm.to("cuda")
53
+
54
+
55
def run_chat_completion(model: AutoRegressiveModel, input: str, temperature: float = 0.01):
    """Upsample a short caption with the chat-finetuned prompt-upsampler model.

    The text2world prompt upsampler is finetuned for chat. During training its
    context window was 512 tokens; for inference max_seq_len is 1024 to
    accommodate longer inputs. Setting `max_gen_len` is optional as the
    finetuned model naturally determines when to stop generating.
    """
    user_message = {"role": "user", "content": f"Upsample the short caption to a long caption: {str(input)}"}

    results = chat_completion(
        model,
        [[user_message]],
        max_gen_len=512,
        temperature=temperature,
        top_p=None,
        top_k=None,
        logprobs=False,
    )
    # Strip formatting artifacts from the generated long caption.
    return str(clean_text(results[0]["generation"]["content"]))
75
+
76
+
77
def clean_text(text: str) -> str:
    """Clean the text by removing prefixes, suffixes, formatting markers, and normalizing whitespace."""
    # Collapse both newline variants into spaces.
    text = text.replace("\n", " ").replace("\r", " ")

    # Handle sections of the form '- **...**': drop short fragments, keep long ones.
    bold_section = r"(- \*\*)(.*?)(\*\*)"

    def _keep_or_drop(match: re.Match[str]) -> str:
        # Fewer than 10 words inside the markers -> remove the whole section;
        # otherwise keep it verbatim.
        inner_words = re.findall(r"\w+", match.group(2))
        return match.group(0) if len(inner_words) >= 10 else ""

    text = re.sub(bold_section, _keep_or_drop, text)

    # Strip one leading prefix, if present (checked in this fixed order).
    for prefix in ("Caption:", "#####", "####", "- ", "* ", ","):
        if text.startswith(prefix):
            text = text[len(prefix):].lstrip()

    # Normalize internal whitespace, then trim leftover punctuation and quotes.
    text = " ".join(text.split())
    return text.strip(' -,*:"\'"“”')
112
+
113
+
114
def parse_args():
    """Parse command-line options for the prompt-upsampler inference demo."""
    parser = argparse.ArgumentParser(description="Run prompt upsampler inference")
    parser.add_argument("--input", type=str, default="A dog is playing with a ball.")
    parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature")
    parser.add_argument("--checkpoint_dir", type=str, default="checkpoints",
                        help="Base directory containing model checkpoints")
    parser.add_argument("--prompt_upsampler_dir", type=str,
                        default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World",
                        help="Prompt upsampler weights directory relative to checkpoint_dir")
    parser.add_argument("--guardrail_dir", type=str, default="Cosmos-1.0-Guardrail",
                        help="Guardrail weights directory relative to checkpoint_dir")
    return parser.parse_args()
134
+
135
+
136
def main(args):
    """Guardrail-check the input, upsample it, guardrail-check the result, and log it."""
    guardrail_runner = guardrail_presets.create_text_guardrail_runner(
        os.path.join(args.checkpoint_dir, args.guardrail_dir)
    )

    # Refuse unsafe input before spending compute on the upsampler.
    if not guardrail_presets.run_text_guardrail(args.input, guardrail_runner):
        log.critical("Input text prompt is not safe.")
        return

    prompt_upsampler = create_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir))
    upsampled_prompt = run_chat_completion(prompt_upsampler, args.input, temperature=args.temperature)

    # The generated prompt must pass the same guardrail before being reported.
    if not guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner):
        log.critical("Upsampled text prompt is not safe.")
        return

    log.info(f"Upsampled prompt: {upsampled_prompt}")
153
+
154
+
155
# Script entry point.
if __name__ == "__main__":
    main(parse_args())
types.py ADDED
@@ -0,0 +1,28 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ from __future__ import annotations
17
+
18
+ from dataclasses import dataclass
19
+ from typing import Optional
20
+
21
+ import torch
22
+
23
+
24
@dataclass
class DenoisePrediction:
    """Container for the outputs of a single diffusion denoising call."""

    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of noise prediction, can be used as a confidence / uncertainty
video2world.py ADDED
@@ -0,0 +1,179 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ from Cosmos.utils import misc
20
+ import torch
21
+
22
+ from Cosmos.inference_utils import add_common_arguments, check_input_frames, validate_args
23
+ from Cosmos.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
24
+ from Cosmos.utils import log
25
+ from Cosmos.utils.io import read_prompts_from_file, save_video
26
+
27
+ torch.enable_grad(False)
28
+
29
+
30
def parse_arguments() -> argparse.Namespace:
    """Build the CLI parser for the video2world demo and parse sys.argv."""
    parser = argparse.ArgumentParser(description="Video to world generation demo script")
    # Options shared by every inference demo script.
    add_common_arguments(parser)

    # Options specific to video2world generation.
    parser.add_argument(
        "--diffusion_transformer_dir",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Video2World",
        choices=[
            "Cosmos-1.0-Diffusion-7B-Video2World",
            "Cosmos-1.0-Diffusion-14B-Video2World",
        ],
        help="DiT model weights directory name relative to checkpoint_dir",
    )
    parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        help="Input video/image path for generating a single video",
    )
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=1,
        choices=[1, 9],
        help="Number of input frames for video2world prediction",
    )

    return parser.parse_args()
66
+
67
+
68
def demo(cfg):
    """Run video-to-world generation demo.

    This function handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        cfg (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    The function will save:
        - Generated MP4 video files
        - Text files containing the processed prompts

    If guardrails block the generation, a critical log message is displayed
    and the function continues to the next prompt if available.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "video2world"
    validate_args(cfg, inference_type)

    # Initialize video2world generation model pipeline
    pipeline = DiffusionVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
    )

    # Handle multiple prompts if prompt file is provided
    if cfg.batch_input_path:
        # Bug fix: was `args.batch_input_path`, which is undefined when demo() is
        # imported and called directly (it relied on a module-level global).
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        # A text prompt is only mandatory when the upsampler is disabled.
        if current_prompt is None and cfg.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_or_video_path = input_dict.get("visual_input", None)
        if current_image_or_video_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue

        # Check input frames
        if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
            continue

        # Generate video
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_or_video_path=current_image_or_video_path,
            negative_prompt=cfg.negative_prompt,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output

        # Batch runs are saved by index; single runs by the configured name.
        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save prompt to text file alongside video (binary mode + explicit
        # encode keeps the bytes identical across platforms).
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
175
+
176
+
177
# Script entry point: parse the CLI arguments once and run the demo.
if __name__ == "__main__":
    args = parse_arguments()
    demo(args)
video2world_hf.py ADDED
@@ -0,0 +1,283 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import argparse
17
+ import os
18
+
19
+ from Cosmos.utils import misc
20
+ import torch
21
+
22
+ from Cosmos.inference_utils import add_common_arguments, check_input_frames, validate_args
23
+ from Cosmos.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
24
+ from Cosmos.utils import log
25
+ from Cosmos.utils.io import read_prompts_from_file, save_video
26
+
27
+ from Cosmos.download_diffusion import main as download_diffusion
28
+ from transformers import PreTrainedModel, PretrainedConfig
29
+
30
+ torch.enable_grad(False)
31
+
32
# Custom HF config class carrying all video2world pipeline settings.
class DiffusionVideo2WorldConfig(PretrainedConfig):
    model_type = "DiffusionVideo2World"

    # Every supported setting with its default value; kwargs override these.
    _SETTING_DEFAULTS = {
        "checkpoint_dir": "checkpoints",
        "tokenizer_dir": "Cosmos-1.0-Tokenizer-CV8x8x8",
        "video_save_name": "output",
        "video_save_folder": "outputs/",
        "prompt": None,
        "batch_input_path": None,
        "negative_prompt": None,
        "num_steps": 35,
        "guidance": 7,
        "num_video_frames": 121,
        "height": 704,
        "width": 1280,
        "fps": 24,
        "seed": 1,
        "disable_prompt_upsampler": False,
        "offload_diffusion_transformer": False,
        "offload_tokenizer": False,
        "offload_text_encoder_model": False,
        "offload_prompt_upsampler": False,
        "offload_guardrail_models": False,
        "diffusion_transformer_dir": "Cosmos-1.0-Diffusion-7B-Video2World",
        "prompt_upsampler_dir": "Pixtral-12B",
        "input_image_or_video_path": None,
        "num_input_frames": 1,
    }

    def __init__(self, **kwargs):
        """Populate every setting from kwargs, falling back to the defaults above."""
        super().__init__(**kwargs)
        for setting, default in self._SETTING_DEFAULTS.items():
            setattr(self, setting, kwargs.get(setting, default))
62
+
63
class DiffusionVideo2World(PreTrainedModel):
    """HF PreTrainedModel wrapper around DiffusionVideo2WorldGenerationPipeline."""

    config_class = DiffusionVideo2WorldConfig

    def __init__(self, config=None):
        """Build the generation pipeline from `config`.

        Args:
            config: DiffusionVideo2WorldConfig with all pipeline settings.
                Bug fix: the default used to be a single shared
                `DiffusionVideo2WorldConfig()` instance created at
                class-definition time (mutable default argument); a fresh
                config is now created per call.
        """
        if config is None:
            config = DiffusionVideo2WorldConfig()
        super().__init__(config)
        cfg = config

        misc.set_random_seed(cfg.seed)
        inference_type = "video2world"
        validate_args(cfg, inference_type)

        self.pipeline = DiffusionVideo2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=cfg.checkpoint_dir,
            checkpoint_name=cfg.diffusion_transformer_dir,
            prompt_upsampler_dir=cfg.prompt_upsampler_dir,
            enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
            offload_network=cfg.offload_diffusion_transformer,
            offload_tokenizer=cfg.offload_tokenizer,
            offload_text_encoder_model=cfg.offload_text_encoder_model,
            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
            offload_guardrail_models=cfg.offload_guardrail_models,
            guidance=cfg.guidance,
            num_steps=cfg.num_steps,
            height=cfg.height,
            width=cfg.width,
            fps=cfg.fps,
            num_video_frames=cfg.num_video_frames,
            seed=cfg.seed,
            num_input_frames=cfg.num_input_frames,
        )

    def forward(self):
        """Generate videos for the configured prompt(s) and save them to disk."""
        cfg = self.config

        # Handle multiple prompts if prompt file is provided
        if cfg.batch_input_path:
            # Bug fix: was `args.batch_input_path` — `args` is undefined here.
            log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
            prompts = read_prompts_from_file(cfg.batch_input_path)
        else:
            # Single prompt case
            prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

        os.makedirs(cfg.video_save_folder, exist_ok=True)
        for i, input_dict in enumerate(prompts):
            current_prompt = input_dict.get("prompt", None)
            # A text prompt is only mandatory when the upsampler is disabled.
            if current_prompt is None and cfg.disable_prompt_upsampler:
                log.critical("Prompt is missing, skipping world generation.")
                continue
            current_image_or_video_path = input_dict.get("visual_input", None)
            if current_image_or_video_path is None:
                log.critical("Visual input is missing, skipping world generation.")
                continue

            # Check input frames
            if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
                continue

            # Generate video.
            # Bug fix: was `pipeline.generate(...)` — `pipeline` is undefined in
            # this scope; the pipeline is stored on the instance in __init__.
            generated_output = self.pipeline.generate(
                prompt=current_prompt,
                image_or_video_path=current_image_or_video_path,
                negative_prompt=cfg.negative_prompt,
            )
            if generated_output is None:
                log.critical("Guardrail blocked video2world generation.")
                continue
            video, prompt = generated_output

            if cfg.batch_input_path:
                video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
            else:
                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

            # Save video
            save_video(
                video=video,
                fps=cfg.fps,
                H=cfg.height,
                W=cfg.width,
                video_save_quality=5,
                video_save_path=video_save_path,
            )

            # Save prompt to text file alongside video
            with open(prompt_save_path, "wb") as f:
                f.write(prompt.encode("utf-8"))

            log.info(f"Saved video to {video_save_path}")
            log.info(f"Saved prompt to {prompt_save_path}")

    def save_pretrained(self, save_directory, **kwargs):
        # Intentionally a no-op: weights are managed by the Cosmos checkpoint
        # downloader, not by HF serialization, but the override is required.
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Download the required diffusion checkpoints, then build the model.

        Expects a `config` kwarg (DiffusionVideo2WorldConfig); all remaining
        kwargs are merged into that config before construction.
        """
        config = kwargs["config"]
        other_args = {k: v for k, v in kwargs.items() if k != "config"}
        config.update(other_args)
        # Pick the matching model size from the configured transformer dir.
        model_sizes = ["7B"] if "7B" in config.diffusion_transformer_dir else ["14B"]
        model_types = ["Video2World"]
        download_diffusion(model_types, model_sizes, config.checkpoint_dir)
        return cls(config)
171
+
172
def demo(cfg):
    """Run video-to-world generation demo.

    This function handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        cfg (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    The function will save:
        - Generated MP4 video files
        - Text files containing the processed prompts

    If guardrails block the generation, a critical log message is displayed
    and the function continues to the next prompt if available.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "video2world"
    validate_args(cfg, inference_type)

    # Initialize video2world generation model pipeline.
    pipeline = DiffusionVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
    )

    # Handle multiple prompts if prompt file is provided.
    if cfg.batch_input_path:
        # FIX: was `args.batch_input_path` — `args` is only defined as a module
        # global when run as a script, so this raised NameError when `demo` was
        # called from imported code.
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case.
        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        # A missing prompt is only fatal when the upsampler cannot invent one.
        if current_prompt is None and cfg.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_or_video_path = input_dict.get("visual_input", None)
        if current_image_or_video_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue

        # Check input frames.
        if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
            continue

        # Generate video.
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_or_video_path=current_image_or_video_path,
            negative_prompt=cfg.negative_prompt,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output

        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video.
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save prompt to text file alongside video.
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
279
+
280
+
281
# Script entry point: parse CLI arguments and run the video2world demo.
if __name__ == "__main__":
    args = parse_arguments()
    demo(args)
video2world_prompt_upsampler_inference.py ADDED
@@ -0,0 +1,167 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ """
17
+ This demo script is used to run inference for Pixtral-12B.
18
+ Command:
19
+ PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py
20
+
21
+ """
22
+
23
+ import argparse
24
+ import os
25
+ from math import ceil
26
+
27
+ from PIL import Image
28
+
29
+ from cosmos1.models.autoregressive.configs.base.model_config import create_vision_language_model_config
30
+ from cosmos1.models.autoregressive.model import AutoRegressiveModel
31
+ from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
32
+ from Cosmos import guardrail_presets as guardrail_presets
33
+ from Cosmos.utils import log
34
+ from Cosmos.utils.io import load_from_fileobj
35
+
36
+
37
def create_vlm_prompt_upsampler(
    checkpoint_dir: str, tokenizer_ckpt_path: str = "mistral-community/pixtral-12b"
) -> AutoRegressiveModel:
    """Build the Pixtral-12B vision-language prompt upsampler and move it to GPU.

    Args:
        checkpoint_dir: Directory containing the fine-tuned ``model.pt`` weights.
        tokenizer_ckpt_path: Tokenizer checkpoint reference; defaults to the
            public Pixtral-12B tokenizer.

    Returns:
        AutoRegressiveModel: The instantiated instruct model on the CUDA device.
    """
    model_config, tokenizer_config = create_vision_language_model_config(
        model_ckpt_path=os.path.join(checkpoint_dir, "model.pt"),
        tokenizer_ckpt_path=tokenizer_ckpt_path,
        model_family="pixtral",
        model_size="12b",
        is_instruct_model=True,
        max_batch_size=1,
        max_seq_len=4300,
        pytorch_rope_version="v1",
    )
    # build() downloads the weights if they are not already cached, then loads them.
    model = AutoRegressiveModel.build(
        model_config=model_config,
        tokenizer_config=tokenizer_config,
    )
    return model.to("cuda")
60
+
61
+
62
def resize_image(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale *image* so neither dimension exceeds *max_size*.

    Aspect ratio is preserved; images already within bounds are returned
    unchanged (the same object, no copy).
    """
    width, height = image.size
    # How far the larger dimension overshoots the allowed size.
    scale = max(width / max_size, height / max_size)
    if scale <= 1:
        return image
    return image.resize((ceil(width / scale), ceil(height / scale)))
72
+
73
+
74
def prepare_dialog(image_or_video_path: str) -> list[dict]:
    """Build a single-turn Pixtral dialog asking for a refined video description.

    Accepts either an ``.mp4`` video (its last frame is used as visual context)
    or a still image path. The image is capped at 1024px per side before being
    attached to the user message.
    """
    if image_or_video_path.endswith(".mp4"):
        # Use the final frame of the video as the visual conditioning input.
        video_np, _ = load_from_fileobj(image_or_video_path, format="mp4")
        image_frame = video_np[-1]
        image = Image.fromarray(image_frame)
    else:
        image: Image.Image = Image.open(image_or_video_path)

    image = resize_image(image, max_size=1024)
    prompt = """\
Your task is to transform a given prompt into a refined and concise video description, no more than 150 words.
Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video.
""".strip()

    # "[IMG]" marks where the attached image is inserted in the Pixtral chat template.
    return [
        {
            "role": "user",
            "content": "[IMG]\n" + prompt,
            "images": [image],
        }
    ]
95
+
96
+
97
def run_chat_completion(pixtral: AutoRegressiveModel, dialog: list[dict], **inference_args) -> str:
    """Run one chat-completion round on *dialog* and return the generated text.

    Args:
        pixtral: The loaded Pixtral prompt-upsampler model.
        dialog: A single dialog (list of role/content messages).
        **inference_args: Overrides for the default sampling arguments.

    Returns:
        str: The model's generated message content.
    """
    call_args = dict(
        max_gen_len=400,
        temperature=0,
        top_p=0.9,
        logprobs=False,
        compile_sampling=False,
        compile_prefill=False,
    )
    call_args.update(inference_args)
    results = chat_completion(
        pixtral,
        [dialog],
        **call_args,
    )
    # One dialog in, exactly one completion out.
    assert len(results) == 1
    return str(results[0]["generation"]["content"])
115
+
116
+
117
def parse_args() -> argparse.Namespace:
    """Parse command-line arguments for the standalone prompt-upsampler demo."""
    p = argparse.ArgumentParser(description="Run prompt upsampler inference")
    p.add_argument(
        "--image_or_video_path",
        type=str,
        default="cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg",
    )
    p.add_argument("--temperature", type=float, default=0.01, help="Inference temperature")
    p.add_argument("--top_p", type=float, default=0.9, help="Top-p value for top-p sampling")
    p.add_argument(
        "--checkpoint_dir",
        type=str,
        default="checkpoints",
        help="Base directory containing model checkpoints",
    )
    p.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    p.add_argument(
        "--guardrail_dir",
        type=str,
        default="Cosmos-1.0-Guardrail",
        help="Guardrail weights directory relative to checkpoint_dir",
    )
    return p.parse_args()
140
+
141
+
142
def main(args):
    """Upsample a prompt from a conditioning image/video and guardrail-check it.

    Logs the upsampled prompt on success; logs critically and returns early if
    the guardrail flags the generated text as unsafe.
    """
    guardrail_runner = guardrail_presets.create_text_guardrail_runner(
        os.path.join(args.checkpoint_dir, args.guardrail_dir)
    )

    upsampler = create_vlm_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir))
    dialog = prepare_dialog(args.image_or_video_path)
    upsampled_prompt = run_chat_completion(
        upsampler,
        dialog,
        max_gen_len=400,
        temperature=args.temperature,
        top_p=args.top_p,
        logprobs=False,
    )
    if not guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner):
        log.critical("Upsampled text prompt is not safe.")
        return

    log.info(f"Upsampled prompt: {upsampled_prompt}")
163
+
164
+
165
# Script entry point: parse CLI arguments and run the upsampler demo.
if __name__ == "__main__":
    args = parse_args()
    main(args)
world_generation_pipeline.py ADDED
@@ -0,0 +1,658 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
2
+ # SPDX-License-Identifier: Apache-2.0
3
+ #
4
+ # Licensed under the Apache License, Version 2.0 (the "License");
5
+ # you may not use this file except in compliance with the License.
6
+ # You may obtain a copy of the License at
7
+ #
8
+ # http://www.apache.org/licenses/LICENSE-2.0
9
+ #
10
+ # Unless required by applicable law or agreed to in writing, software
11
+ # distributed under the License is distributed on an "AS IS" BASIS,
12
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
13
+ # See the License for the specific language governing permissions and
14
+ # limitations under the License.
15
+
16
+ import gc
17
+ import os
18
+ from typing import Any, Optional
19
+
20
+ import numpy as np
21
+ import torch
22
+
23
+ from Cosmos.base_world_generation_pipeline import BaseWorldGenerationPipeline
24
+ from Cosmos.inference_utils import (
25
+ generate_world_from_text,
26
+ generate_world_from_video,
27
+ get_condition_latent,
28
+ get_video_batch,
29
+ load_model_by_config,
30
+ load_network_model,
31
+ load_tokenizer_model,
32
+ )
33
+ from Cosmos.model_t2w import DiffusionT2WModel
34
+ from Cosmos.model_v2w import DiffusionV2WModel
35
+ from Cosmos.text2world_prompt_upsampler_inference import (
36
+ create_prompt_upsampler,
37
+ run_chat_completion,
38
+ )
39
+ from Cosmos.video2world_prompt_upsampler_inference import (
40
+ create_vlm_prompt_upsampler,
41
+ prepare_dialog,
42
+ )
43
+ from Cosmos.video2world_prompt_upsampler_inference import (
44
+ run_chat_completion as run_chat_completion_vlm,
45
+ )
46
+ from Cosmos.utils import log
47
+
48
# Maps released checkpoint directory names to the config job names that
# load_model_by_config uses when instantiating the diffusion transformer.
MODEL_NAME_DICT = {
    "Cosmos-1.0-Diffusion-7B-Text2World": "Cosmos_1_0_Diffusion_Text2World_7B",
    "Cosmos-1.0-Diffusion-14B-Text2World": "Cosmos_1_0_Diffusion_Text2World_14B",
    "Cosmos-1.0-Diffusion-7B-Video2World": "Cosmos_1_0_Diffusion_Video2World_7B",
    "Cosmos-1.0-Diffusion-14B-Video2World": "Cosmos_1_0_Diffusion_Video2World_14B",
}
54
+
55
+
56
class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline):
    def __init__(
        self,
        inference_type: str,
        checkpoint_dir: str,
        checkpoint_name: str,
        prompt_upsampler_dir: Optional[str] = None,
        enable_prompt_upsampler: bool = True,
        enable_text_guardrail: bool = True,
        enable_video_guardrail: bool = True,
        offload_network: bool = False,
        offload_tokenizer: bool = False,
        offload_text_encoder_model: bool = False,
        offload_prompt_upsampler: bool = False,
        offload_guardrail_models: bool = False,
        guidance: float = 7.0,
        num_steps: int = 35,
        height: int = 704,
        width: int = 1280,
        fps: int = 24,
        num_video_frames: int = 121,
        seed: int = 0,
    ):
        """Initialize the diffusion world generation pipeline.

        Args:
            inference_type: Type of world generation ('text2world' or 'video2world')
            checkpoint_dir: Base directory containing model checkpoints
            checkpoint_name: Name of the diffusion transformer checkpoint to use
            prompt_upsampler_dir: Directory containing prompt upsampler model weights
            enable_prompt_upsampler: Whether to use prompt upsampling
            enable_text_guardrail: Whether to enable text guardrail
            enable_video_guardrail: Whether to enable video guardrail
            offload_network: Whether to offload diffusion transformer after inference
            offload_tokenizer: Whether to offload tokenizer after inference
            offload_text_encoder_model: Whether to offload T5 model after inference
            offload_prompt_upsampler: Whether to offload prompt upsampler
            offload_guardrail_models: Whether to offload guardrail models
            guidance: Classifier-free guidance scale
            num_steps: Number of diffusion sampling steps
            height: Height of output video
            width: Width of output video
            fps: Frames per second of output video
            num_video_frames: Number of frames to generate
            seed: Random seed for sampling

        Raises:
            KeyError: If ``checkpoint_name`` is not a released checkpoint
                listed in MODEL_NAME_DICT.
        """
        assert inference_type in [
            "text2world",
            "video2world",
        ], "Invalid inference_type, must be 'text2world' or 'video2world'"

        # Resolve the config job name before the (potentially expensive)
        # base-class initialization.
        self.model_name = MODEL_NAME_DICT[checkpoint_name]
        self.guidance = guidance
        self.num_steps = num_steps
        self.height = height
        self.width = width
        self.fps = fps
        self.num_video_frames = num_video_frames
        self.seed = seed

        super().__init__(
            inference_type=inference_type,
            checkpoint_dir=checkpoint_dir,
            checkpoint_name=checkpoint_name,
            enable_text_guardrail=enable_text_guardrail,
            enable_video_guardrail=enable_video_guardrail,
            offload_network=offload_network,
            offload_tokenizer=offload_tokenizer,
            offload_text_encoder_model=offload_text_encoder_model,
            offload_guardrail_models=offload_guardrail_models,
        )
        self.prompt_upsampler_dir = prompt_upsampler_dir
        self.enable_prompt_upsampler = enable_prompt_upsampler
        self.offload_prompt_upsampler = offload_prompt_upsampler

        self.prompt_upsampler = None
        # Eagerly load the upsampler only when it will stay resident; an
        # offloaded upsampler is loaded on demand per call instead.
        if enable_prompt_upsampler and not offload_prompt_upsampler:
            self._load_prompt_upsampler_model()
134
+
135
+ def _load_prompt_upsampler_model(self):
136
+ self.prompt_upsampler = create_prompt_upsampler(
137
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir),
138
+ )
139
+
140
    def _load_model(self):
        """Instantiate the text2world diffusion transformer (weights loaded separately)."""
        self.model = load_model_by_config(
            config_job_name=self.model_name,
            config_file="cosmos1/models/diffusion/config/config.py",
            model_class=DiffusionT2WModel,
        )
146
+
147
    def _load_network(self):
        """Load the diffusion transformer weights from the selected checkpoint."""
        load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}/model.pt")
149
+
150
    def _load_tokenizer(self):
        """Load the Cosmos-1.0-Tokenizer-CV8x8x8 video tokenizer weights into the model."""
        load_tokenizer_model(self.model, f"{self.checkpoint_dir}/Cosmos-1.0-Tokenizer-CV8x8x8")
152
+
153
+ def _offload_prompt_upsampler_model(self):
154
+ """Move prompt enhancement model to CPU/disk.
155
+
156
+ Offloads prompt upsampling model after processing input
157
+ to reduce GPU memory usage.
158
+ """
159
+ if self.prompt_upsampler:
160
+ del self.prompt_upsampler
161
+ self.prompt_upsampler = None
162
+ gc.collect()
163
+ torch.cuda.empty_cache()
164
+
165
+ def _run_prompt_upsampler_on_prompt(self, prompt: str) -> str:
166
+ """Enhance the input prompt using the prompt upsampler model.
167
+
168
+ Args:
169
+ prompt: Raw text prompt to be enhanced
170
+
171
+ Returns:
172
+ str: Enhanced version of the input prompt with more descriptive details
173
+ """
174
+ upsampled_prompt = run_chat_completion(self.prompt_upsampler, prompt)
175
+ log.info(f"Upsampled prompt: {upsampled_prompt}")
176
+ return upsampled_prompt
177
+
178
+ def _run_prompt_upsampler_on_prompt_with_offload(self, *args: Any, **kwargs: Any) -> str:
179
+ """Enhance prompt with prompt upsampler model.
180
+
181
+ Args:
182
+ *args: Positional arguments
183
+ **kwargs: Keyword arguments
184
+
185
+ Returns:
186
+ Enhanced prompt string
187
+ """
188
+ if self.offload_prompt_upsampler:
189
+ self._load_prompt_upsampler_model()
190
+
191
+ enhanced_prompt = self._run_prompt_upsampler_on_prompt(*args, **kwargs)
192
+
193
+ if self.offload_prompt_upsampler:
194
+ self._offload_prompt_upsampler_model()
195
+
196
+ return enhanced_prompt
197
+
198
    def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray:
        """Decode latent samples to video frames using the tokenizer decoder.

        Args:
            sample: Latent tensor from diffusion model [B, C, T, H, W]

        Returns:
            np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C]
            with values in range [0, 255]
        """
        # Decode video: shift/clamp/halve maps the decoder output into [0, 1]
        # (the decoder is presumably in [-1, 1] — the clamp bounds the
        # intermediate [0, 2] range before halving).
        video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2  # [B, 3, T, H, W]
        # Keep only the first batch element and convert C,T,H,W -> T,H,W,C uint8.
        video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy()

        return video
213
+
214
    def _run_model(
        self,
        embedding: torch.Tensor,
        negative_prompt_embedding: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generate video latents using the diffusion model.

        Args:
            embedding: Text embedding tensor from text encoder
            negative_prompt_embedding: Optional embedding for negative prompt guidance

        Returns:
            torch.Tensor: Generated video latents before tokenizer decoding

        Note:
            The model and tokenizer are automatically offloaded after inference
            if offloading is enabled in the config.
        """
        # Get video batch and state shape
        data_batch, state_shape = get_video_batch(
            model=self.model,
            prompt_embedding=embedding,
            negative_prompt_embedding=negative_prompt_embedding,
            height=self.height,
            width=self.width,
            fps=self.fps,
            num_video_frames=self.num_video_frames,
        )

        # Generate video frames; negative-prompt guidance is active only when
        # an embedding for it was actually supplied.
        sample = generate_world_from_text(
            model=self.model,
            state_shape=state_shape,
            is_negative_prompt=True if negative_prompt_embedding is not None else False,
            data_batch=data_batch,
            guidance=self.guidance,
            num_steps=self.num_steps,
            seed=self.seed,
        )

        return sample
255
+
256
    def _run_model_with_offload(
        self, prompt_embedding: torch.Tensor, negative_prompt_embedding: Optional[torch.Tensor] = None
    ) -> np.ndarray:
        """Generate world representation with automatic model offloading.

        Wraps the core generation process with model loading/offloading logic
        to minimize GPU memory usage during inference.

        Args:
            prompt_embedding: Text embedding tensor from the text encoder
            negative_prompt_embedding: Optional embedding for negative prompt guidance

        Returns:
            np.ndarray: Generated world representation as numpy array
        """
        if self.offload_network:
            self._load_network()

        if self.offload_tokenizer:
            self._load_tokenizer()

        sample = self._run_model(prompt_embedding, negative_prompt_embedding)

        if self.offload_network:
            self._offload_network()

        if self.offload_tokenizer:
            # NOTE(review): the tokenizer was loaded above and never offloaded
            # in between, so this second _load_tokenizer() looks redundant —
            # confirm whether _offload_network can invalidate the tokenizer
            # before simplifying.
            self._load_tokenizer()

        sample = self._run_tokenizer_decoding(sample)

        if self.offload_tokenizer:
            self._offload_tokenizer()
        return sample
290
+
291
    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        word_limit_to_skip_upsampler: Optional[int] = None,
    ) -> tuple[np.ndarray, str] | None:
        """Generate video from text prompt with optional negative prompt guidance.

        Pipeline steps:
        1. Run safety checks on input prompt
        2. Enhance prompt using upsampler if enabled
        3. Run safety checks on upsampled prompt if applicable
        4. Convert prompt to embeddings
        5. Generate video frames using diffusion
        6. Run safety checks and apply face blur on generated video frames

        Args:
            prompt: Text description of desired video
            negative_prompt: Optional text to guide what not to generate
            word_limit_to_skip_upsampler: Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value
        Returns:
            tuple: (
                Generated video frames as uint8 np.ndarray [T, H, W, C],
                Final prompt used for generation (may be enhanced)
            ), or None if content fails guardrail safety checks
        """
        log.info(f"Run with prompt: {prompt}")
        log.info(f"Run with negative prompt: {negative_prompt}")
        log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}")

        if self.enable_text_guardrail:
            log.info("Run guardrail on prompt")
            is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
            if not is_safe:
                log.critical("Input text prompt is not safe")
                return None
            log.info("Pass guardrail on prompt")

        # Enhance prompt; very long prompts skip upsampling for robustness.
        if self.enable_prompt_upsampler:
            word_count = len(prompt.split())
            if word_limit_to_skip_upsampler is None or word_count <= word_limit_to_skip_upsampler:
                log.info("Run prompt upsampler on prompt")
                prompt = self._run_prompt_upsampler_on_prompt_with_offload(prompt)
                # Re-check safety: the upsampler may have introduced new content.
                if self.enable_text_guardrail:
                    log.info("Run guardrail on upsampled prompt")
                    is_safe = self._run_guardrail_on_prompt_with_offload(prompt=prompt)
                    if not is_safe:
                        log.critical("Upsampled text prompt is not safe")
                        return None
                    log.info("Pass guardrail on upsampled prompt")
            else:
                log.info(
                    f"Skip prompt upsampler for better robustness because the number of words ({word_count}) in the prompt is greater than {word_limit_to_skip_upsampler}"
                )

        # Embed the (possibly enhanced) prompt and the optional negative prompt
        # in a single batch so both share one text-encoder pass.
        log.info("Run text embedding on prompt")
        if negative_prompt:
            prompts = [prompt, negative_prompt]
        else:
            prompts = [prompt]
        prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts)
        prompt_embedding = prompt_embeddings[0]
        negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None
        log.info("Finish text embedding on prompt")

        # Generate video
        log.info("Run generation")
        video = self._run_model_with_offload(
            prompt_embedding,
            negative_prompt_embedding=negative_prompt_embedding,
        )
        log.info("Finish generation")

        if self.enable_video_guardrail:
            log.info("Run guardrail on generated video")
            video = self._run_guardrail_on_video_with_offload(video)
            if video is None:
                log.critical("Generated video is not safe")
                return None
            log.info("Pass guardrail on generated video")

        return video, prompt
374
+
375
+
376
class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline):
    def __init__(
        self,
        inference_type: str,
        checkpoint_dir: str,
        checkpoint_name: str,
        prompt_upsampler_dir: Optional[str] = None,
        enable_prompt_upsampler: bool = True,
        enable_text_guardrail: bool = True,
        enable_video_guardrail: bool = True,
        offload_network: bool = False,
        offload_tokenizer: bool = False,
        offload_text_encoder_model: bool = False,
        offload_prompt_upsampler: bool = False,
        offload_guardrail_models: bool = False,
        guidance: float = 7.0,
        num_steps: int = 35,
        height: int = 704,
        width: int = 1280,
        fps: int = 24,
        num_video_frames: int = 121,
        seed: int = 0,
        num_input_frames: int = 1,
    ):
        """Initialize diffusion world generation pipeline.

        Args:
            inference_type: Type of world generation ('text2world' or 'video2world')
            checkpoint_dir: Base directory containing model checkpoints
            checkpoint_name: Name of the diffusion transformer checkpoint to use
            prompt_upsampler_dir: Directory containing prompt upsampler model weights
            enable_prompt_upsampler: Whether to use prompt upsampling
            enable_text_guardrail: Whether to enable text guardrail
            enable_video_guardrail: Whether to enable video guardrail
            offload_network: Whether to offload diffusion transformer after inference
            offload_tokenizer: Whether to offload tokenizer after inference
            offload_text_encoder_model: Whether to offload T5 model after inference
            offload_prompt_upsampler: Whether to offload prompt upsampler
            offload_guardrail_models: Whether to offload guardrail models
            guidance: Classifier-free guidance scale
            num_steps: Number of diffusion sampling steps
            height: Height of output video
            width: Width of output video
            fps: Frames per second of output video
            num_video_frames: Number of frames to generate
            seed: Random seed for sampling
            num_input_frames: Number of latent conditions
        """
        # Set before super().__init__ because the base initializer may trigger
        # model loading that reads this attribute.
        self.num_input_frames = num_input_frames
        super().__init__(
            inference_type=inference_type,
            checkpoint_dir=checkpoint_dir,
            checkpoint_name=checkpoint_name,
            prompt_upsampler_dir=prompt_upsampler_dir,
            enable_prompt_upsampler=enable_prompt_upsampler,
            enable_text_guardrail=enable_text_guardrail,
            enable_video_guardrail=enable_video_guardrail,
            offload_network=offload_network,
            offload_tokenizer=offload_tokenizer,
            offload_text_encoder_model=offload_text_encoder_model,
            offload_prompt_upsampler=offload_prompt_upsampler,
            offload_guardrail_models=offload_guardrail_models,
            guidance=guidance,
            num_steps=num_steps,
            height=height,
            width=width,
            fps=fps,
            num_video_frames=num_video_frames,
            seed=seed,
        )
446
+
447
+ def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str:
448
+ """Enhance the input prompt using visual context from the conditioning image.
449
+
450
+ Args:
451
+ image_or_video_path: Path to conditioning image or video used for visual context
452
+
453
+ Returns:
454
+ str: Enhanced prompt incorporating visual details from the image
455
+ """
456
+ dialog = prepare_dialog(image_or_video_path)
457
+ upsampled_prompt = run_chat_completion_vlm(
458
+ self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False
459
+ )
460
+ log.info(f"Upsampled prompt: {upsampled_prompt}")
461
+ return upsampled_prompt
462
+
463
+ def _load_prompt_upsampler_model(self):
464
+ self.prompt_upsampler = create_vlm_prompt_upsampler(
465
+ checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir),
466
+ )
467
+
468
+ def _load_model(self):
469
+ self.model = load_model_by_config(
470
+ config_job_name=self.model_name,
471
+ config_file="cosmos1/models/diffusion/config/config.py",
472
+ model_class=DiffusionV2WModel,
473
+ )
474
+
475
    def _run_model(
        self,
        embedding: torch.Tensor,
        condition_latent: torch.Tensor,
        negative_prompt_embedding: torch.Tensor | None = None,
    ) -> torch.Tensor:
        """Generate video frames using the diffusion model.

        Args:
            embedding: Text embedding tensor from T5 encoder
            condition_latent: Latent tensor from conditioning image or video
            negative_prompt_embedding: Optional embedding for negative prompt guidance

        Returns:
            Tensor of generated video frames

        Note:
            Model and tokenizer are automatically offloaded after inference
            if offloading is enabled.
        """
        # Get video batch and state shape
        data_batch, state_shape = get_video_batch(
            model=self.model,
            prompt_embedding=embedding,
            negative_prompt_embedding=negative_prompt_embedding,
            height=self.height,
            width=self.width,
            fps=self.fps,
            num_video_frames=self.num_video_frames,
        )

        # Generate video frames.
        # NOTE(review): the `state_shape` returned by get_video_batch is ignored
        # in favor of self.model.state_shape, and is_negative_prompt is
        # hard-coded True even when negative_prompt_embedding is None (unlike
        # the text2world variant) — confirm both are intentional.
        video = generate_world_from_video(
            model=self.model,
            state_shape=self.model.state_shape,
            is_negative_prompt=True,
            data_batch=data_batch,
            guidance=self.guidance,
            num_steps=self.num_steps,
            seed=self.seed,
            condition_latent=condition_latent,
            num_input_frames=self.num_input_frames,
        )

        return video
520
+
521
+ def _run_tokenizer_encoding(self, image_or_video_path: str) -> torch.Tensor:
522
+ """
523
+ Encode image to latent space
524
+
525
+ Args:
526
+ image_or_video_path: Path to conditioning image
527
+
528
+ Returns:
529
+ torch.Tensor: Latent tensor from tokenizer encoding
530
+ """
531
+ condition_latent = get_condition_latent(
532
+ model=self.model,
533
+ input_image_or_video_path=image_or_video_path,
534
+ num_input_frames=self.num_input_frames,
535
+ state_shape=self.model.state_shape,
536
+ )
537
+
538
+ return condition_latent
539
+
540
+ def _run_model_with_offload(
541
+ self,
542
+ prompt_embedding: torch.Tensor,
543
+ image_or_video_path: str,
544
+ negative_prompt_embedding: Optional[torch.Tensor] = None,
545
+ ) -> np.ndarray:
546
+ """Generate world representation with automatic model offloading.
547
+
548
+ Wraps the core generation process with model loading/offloading logic
549
+ to minimize GPU memory usage during inference.
550
+
551
+ Args:
552
+ prompt_embedding: Text embedding tensor from T5 encoder
553
+ image_or_video_path: Path to conditioning image or video
554
+ negative_prompt_embedding: Optional embedding for negative prompt guidance
555
+
556
+ Returns:
557
+ np.ndarray: Generated world representation as numpy array
558
+ """
559
+ if self.offload_tokenizer:
560
+ self._load_tokenizer()
561
+
562
+ condition_latent = self._run_tokenizer_encoding(image_or_video_path)
563
+
564
+ if self.offload_network:
565
+ self._load_network()
566
+
567
+ sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding)
568
+
569
+ if self.offload_network:
570
+ self._offload_network()
571
+
572
+ sample = self._run_tokenizer_decoding(sample)
573
+
574
+ if self.offload_tokenizer:
575
+ self._offload_tokenizer()
576
+
577
+ return sample
578
+
579
+ def generate(
580
+ self,
581
+ prompt: str,
582
+ image_or_video_path: str,
583
+ negative_prompt: Optional[str] = None,
584
+ ) -> tuple[np.ndarray, str] | None:
585
+ """Generate video from text prompt and optional image.
586
+
587
+ Pipeline steps:
588
+ 1. Run safety checks on input prompt
589
+ 2. Enhance prompt using upsampler if enabled
590
+ 3. Run safety checks on upsampled prompt if applicable
591
+ 4. Convert prompt to embeddings
592
+ 5. Generate video frames using diffusion
593
+ 6. Run safety checks and apply face blur on generated video frames
594
+
595
+ Args:
596
+ prompt: Text description of desired video
597
+ image_or_video_path: Path to conditioning image or video
598
+ negative_prompt: Optional text to guide what not to generate
599
+
600
+ Returns:
601
+ tuple: (
602
+ Generated video frames as uint8 np.ndarray [T, H, W, C],
603
+ Final prompt used for generation (may be enhanced)
604
+ ), or None if content fails guardrail safety checks
605
+ """
606
+ log.info(f"Run with prompt: {prompt}")
607
+ log.info(f"Run with image or video path: {image_or_video_path}")
608
+ log.info(f"Run with negative prompt: {negative_prompt}")
609
+ log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}")
610
+
611
+ if self.enable_text_guardrail and not self.enable_prompt_upsampler:
612
+ log.info("Run guardrail on prompt")
613
+ is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
614
+ if not is_safe:
615
+ log.critical("Input text prompt is not safe")
616
+ return None
617
+ log.info("Pass guardrail on prompt")
618
+
619
+ # Enhance prompt
620
+ if self.enable_prompt_upsampler:
621
+ log.info("Run prompt upsampler on image or video, input prompt is not used")
622
+ prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path)
623
+ if self.enable_text_guardrail:
624
+ log.info("Run guardrail on upsampled prompt")
625
+ is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
626
+ if not is_safe:
627
+ log.critical("Upsampled text prompt is not safe")
628
+ return None
629
+ log.info("Pass guardrail on upsampled prompt")
630
+
631
+ log.info("Run text embedding on prompt")
632
+ if negative_prompt:
633
+ prompts = [prompt, negative_prompt]
634
+ else:
635
+ prompts = [prompt]
636
+ prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts)
637
+ prompt_embedding = prompt_embeddings[0]
638
+ negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None
639
+ log.info("Finish text embedding on prompt")
640
+
641
+ # Generate video
642
+ log.info("Run generation")
643
+ video = self._run_model_with_offload(
644
+ prompt_embedding,
645
+ negative_prompt_embedding=negative_prompt_embedding,
646
+ image_or_video_path=image_or_video_path,
647
+ )
648
+ log.info("Finish generation")
649
+
650
+ if self.enable_video_guardrail:
651
+ log.info("Run guardrail on generated video")
652
+ video = self._run_guardrail_on_video_with_offload(video)
653
+ if video is None:
654
+ log.critical("Generated video is not safe")
655
+ return None
656
+ log.info("Pass guardrail on generated video")
657
+
658
+ return video, prompt