Upload modified files
Browse files- base_world_generation_pipeline.py +362 -0
- conditioner.py +323 -0
- config.json +10 -0
- convert_pixtral_ckpt.py +209 -0
- download_diffusion.py +123 -0
- guardrail_presets.py +77 -0
- inference_utils.py +726 -0
- model_t2w.py +282 -0
- model_v2w.py +341 -0
- t5_text_encoder.py +108 -0
- text2world.py +161 -0
- text2world_prompt_upsampler_inference.py +157 -0
- types.py +28 -0
- video2world.py +179 -0
- video2world_hf.py +283 -0
- video2world_prompt_upsampler_inference.py +167 -0
- world_generation_pipeline.py +658 -0
base_world_generation_pipeline.py
ADDED
|
@@ -0,0 +1,362 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import gc
|
| 17 |
+
import os
|
| 18 |
+
from abc import ABC
|
| 19 |
+
from typing import Any
|
| 20 |
+
|
| 21 |
+
import numpy as np
|
| 22 |
+
import torch
|
| 23 |
+
|
| 24 |
+
from Cosmos.t5_text_encoder import CosmosT5TextEncoder
|
| 25 |
+
from Cosmos import guardrail_presets as guardrail_presets
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
class BaseWorldGenerationPipeline(ABC):
|
| 29 |
+
def __init__(
|
| 30 |
+
self,
|
| 31 |
+
inference_type: str | None = None,
|
| 32 |
+
checkpoint_dir: str | None = None,
|
| 33 |
+
checkpoint_name: str | None = None,
|
| 34 |
+
enable_text_guardrail: bool = False,
|
| 35 |
+
enable_video_guardrail: bool = False,
|
| 36 |
+
offload_network: bool = False,
|
| 37 |
+
offload_tokenizer: bool = False,
|
| 38 |
+
offload_text_encoder_model: bool = False,
|
| 39 |
+
offload_guardrail_models: bool = False,
|
| 40 |
+
):
|
| 41 |
+
"""Initialize base world generation pipeline.
|
| 42 |
+
|
| 43 |
+
This abstract base class provides core functionality for world generation models including:
|
| 44 |
+
- Model loading and initialization
|
| 45 |
+
- Text encoding and embedding
|
| 46 |
+
- Safety checks and content filtering
|
| 47 |
+
- Memory management through model offloading
|
| 48 |
+
|
| 49 |
+
Args:
|
| 50 |
+
inference_type: The type of inference pipeline ("text2world" or "video2world")
|
| 51 |
+
checkpoint_dir: Root directory containing model checkpoints
|
| 52 |
+
checkpoint_name: Name of the specific checkpoint file to load
|
| 53 |
+
enable_text_guardrail: If True, validates input prompts for safety
|
| 54 |
+
enable_video_guardrail: If True, validates generated videos for safety
|
| 55 |
+
offload_network: If True, moves main model to CPU after inference
|
| 56 |
+
offload_tokenizer: If True, moves tokenizer to CPU after use
|
| 57 |
+
offload_text_encoder_model: If True, moves T5 encoder to CPU after encoding
|
| 58 |
+
offload_guardrail_models: If True, moves safety models to CPU after checks
|
| 59 |
+
"""
|
| 60 |
+
self.inference_type = inference_type
|
| 61 |
+
self.checkpoint_dir = checkpoint_dir
|
| 62 |
+
self.checkpoint_name = checkpoint_name
|
| 63 |
+
self.guardrail_dir = "Cosmos-1.0-Guardrail"
|
| 64 |
+
self.enable_text_guardrail = enable_text_guardrail
|
| 65 |
+
self.enable_video_guardrail = enable_video_guardrail
|
| 66 |
+
|
| 67 |
+
# Add offloading flags
|
| 68 |
+
self.offload_network = offload_network
|
| 69 |
+
self.offload_tokenizer = offload_tokenizer
|
| 70 |
+
self.offload_text_encoder_model = offload_text_encoder_model
|
| 71 |
+
self.offload_guardrail_models = offload_guardrail_models
|
| 72 |
+
|
| 73 |
+
# Initialize model instances
|
| 74 |
+
self.text_guardrail = None
|
| 75 |
+
self.video_guardrail = None
|
| 76 |
+
self.text_encoder = None
|
| 77 |
+
self.model = None
|
| 78 |
+
|
| 79 |
+
self._load_model()
|
| 80 |
+
|
| 81 |
+
if not self.offload_text_encoder_model:
|
| 82 |
+
self._load_text_encoder_model()
|
| 83 |
+
if not self.offload_guardrail_models:
|
| 84 |
+
if self.enable_text_guardrail:
|
| 85 |
+
self._load_text_guardrail()
|
| 86 |
+
if self.enable_video_guardrail:
|
| 87 |
+
self._load_video_guardrail()
|
| 88 |
+
if not self.offload_network:
|
| 89 |
+
self._load_network()
|
| 90 |
+
if not self.offload_tokenizer:
|
| 91 |
+
self._load_tokenizer()
|
| 92 |
+
|
| 93 |
+
def _load_tokenizer(self):
|
| 94 |
+
pass
|
| 95 |
+
|
| 96 |
+
def _load_network(self):
|
| 97 |
+
pass
|
| 98 |
+
|
| 99 |
+
def _load_model(self, checkpoint_name: str) -> Any:
|
| 100 |
+
"""Load the world generation model from a checkpoint.
|
| 101 |
+
|
| 102 |
+
This abstract method must be implemented by subclasses to load their specific
|
| 103 |
+
model architecture and weights.
|
| 104 |
+
|
| 105 |
+
Args:
|
| 106 |
+
checkpoint_name: Path to the model checkpoint file
|
| 107 |
+
|
| 108 |
+
Returns:
|
| 109 |
+
The loaded model instance
|
| 110 |
+
|
| 111 |
+
Raises:
|
| 112 |
+
NotImplementedError: Must be implemented by subclasses
|
| 113 |
+
"""
|
| 114 |
+
pass
|
| 115 |
+
|
| 116 |
+
def _load_text_encoder_model(self):
|
| 117 |
+
"""Load the T5 text encoder model.
|
| 118 |
+
|
| 119 |
+
Initializes and loads the T5 encoder model used for converting text prompts
|
| 120 |
+
into embeddings that condition the world generation model.
|
| 121 |
+
|
| 122 |
+
Returns:
|
| 123 |
+
Loaded T5 text encoder model instance
|
| 124 |
+
"""
|
| 125 |
+
self.text_encoder = CosmosT5TextEncoder(cache_dir=self.checkpoint_dir)
|
| 126 |
+
|
| 127 |
+
def _load_text_guardrail(self):
|
| 128 |
+
"""Load text safety classifier models.
|
| 129 |
+
|
| 130 |
+
Initializes models used for checking input prompts against safety policies.
|
| 131 |
+
Models are loaded from the specified guardrail directory.
|
| 132 |
+
"""
|
| 133 |
+
self.text_guardrail = guardrail_presets.create_text_guardrail_runner(
|
| 134 |
+
checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
|
| 135 |
+
)
|
| 136 |
+
|
| 137 |
+
def _load_video_guardrail(self):
|
| 138 |
+
"""Load video safety classifier models.
|
| 139 |
+
|
| 140 |
+
Initializes models used for validating generated video content against
|
| 141 |
+
safety policies. Models are loaded from the specified guardrail directory.
|
| 142 |
+
"""
|
| 143 |
+
self.video_guardrail = guardrail_presets.create_video_guardrail_runner(
|
| 144 |
+
checkpoint_dir=os.path.join(self.checkpoint_dir, self.guardrail_dir)
|
| 145 |
+
)
|
| 146 |
+
|
| 147 |
+
def _offload_network(self):
|
| 148 |
+
if self.model.model:
|
| 149 |
+
del self.model.model
|
| 150 |
+
self.model.model = None
|
| 151 |
+
gc.collect()
|
| 152 |
+
torch.cuda.empty_cache()
|
| 153 |
+
|
| 154 |
+
def _offload_tokenizer(self):
|
| 155 |
+
if self.model.tokenizer:
|
| 156 |
+
del self.model.tokenizer
|
| 157 |
+
self.model.tokenizer = None
|
| 158 |
+
gc.collect()
|
| 159 |
+
torch.cuda.empty_cache()
|
| 160 |
+
|
| 161 |
+
def _offload_guardrail_models(self):
|
| 162 |
+
"""Offload safety classifier models to reduce memory usage.
|
| 163 |
+
|
| 164 |
+
Moves safety models to CPU and clears GPU memory if they are no longer needed.
|
| 165 |
+
This helps manage memory when processing multiple inputs sequentially.
|
| 166 |
+
"""
|
| 167 |
+
if self.text_guardrail:
|
| 168 |
+
del self.text_guardrail
|
| 169 |
+
self.text_guardrail = None
|
| 170 |
+
if self.video_guardrail:
|
| 171 |
+
del self.video_guardrail
|
| 172 |
+
self.video_guardrail = None
|
| 173 |
+
gc.collect()
|
| 174 |
+
torch.cuda.empty_cache()
|
| 175 |
+
|
| 176 |
+
def _offload_text_encoder_model(self):
|
| 177 |
+
"""Offload T5 text encoder to reduce memory usage.
|
| 178 |
+
|
| 179 |
+
Moves the T5 encoder to CPU and clears GPU memory after text encoding is complete.
|
| 180 |
+
This helps manage memory when processing multiple inputs sequentially.
|
| 181 |
+
"""
|
| 182 |
+
if self.text_encoder:
|
| 183 |
+
del self.text_encoder
|
| 184 |
+
self.text_encoder = None
|
| 185 |
+
gc.collect()
|
| 186 |
+
torch.cuda.empty_cache()
|
| 187 |
+
|
| 188 |
+
def _run_model(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
| 189 |
+
"""Generate world latents using the model.
|
| 190 |
+
|
| 191 |
+
This abstract method must be implemented by subclasses to define their specific
|
| 192 |
+
generation process.
|
| 193 |
+
|
| 194 |
+
Args:
|
| 195 |
+
*args: Variable positional arguments for model inference
|
| 196 |
+
**kwargs: Variable keyword arguments for model inference
|
| 197 |
+
|
| 198 |
+
Returns:
|
| 199 |
+
torch.Tensor: Generated world representation tensor
|
| 200 |
+
"""
|
| 201 |
+
pass
|
| 202 |
+
|
| 203 |
+
def _run_model_with_offload(self, *args: Any, **kwargs: Any) -> torch.Tensor:
|
| 204 |
+
"""Generate world representation with memory management.
|
| 205 |
+
|
| 206 |
+
Handles loading the model before inference and offloading afterward if enabled.
|
| 207 |
+
This helps minimize GPU memory usage during inference.
|
| 208 |
+
|
| 209 |
+
Args:
|
| 210 |
+
*args: Arguments passed to _run_model
|
| 211 |
+
**kwargs: Keyword arguments passed to _run_model
|
| 212 |
+
|
| 213 |
+
Returns:
|
| 214 |
+
np.ndarray: Generated world representation as numpy array
|
| 215 |
+
"""
|
| 216 |
+
pass
|
| 217 |
+
|
| 218 |
+
def _run_guardrail_on_prompt(self, prompt: str) -> bool:
|
| 219 |
+
"""Check if prompt meets safety requirements.
|
| 220 |
+
|
| 221 |
+
Validates the input prompt against safety policies using loaded guardrail models.
|
| 222 |
+
|
| 223 |
+
Args:
|
| 224 |
+
prompt: Raw text prompt to validate
|
| 225 |
+
|
| 226 |
+
Returns:
|
| 227 |
+
bool: True if prompt passes all safety checks, False otherwise
|
| 228 |
+
"""
|
| 229 |
+
return guardrail_presets.run_text_guardrail(prompt, self.text_guardrail)
|
| 230 |
+
|
| 231 |
+
def _run_guardrail_on_prompt_with_offload(self, prompt: str) -> bool:
|
| 232 |
+
"""Check prompt safety with memory management.
|
| 233 |
+
|
| 234 |
+
Validates prompt safety while handling model loading/offloading to manage memory.
|
| 235 |
+
|
| 236 |
+
Args:
|
| 237 |
+
prompt: Raw text prompt to validate
|
| 238 |
+
|
| 239 |
+
Returns:
|
| 240 |
+
bool: True if prompt passes all safety checks, False otherwise
|
| 241 |
+
"""
|
| 242 |
+
if self.offload_guardrail_models:
|
| 243 |
+
self._load_text_guardrail()
|
| 244 |
+
|
| 245 |
+
is_safe = self._run_guardrail_on_prompt(prompt)
|
| 246 |
+
|
| 247 |
+
if self.offload_guardrail_models:
|
| 248 |
+
self._offload_guardrail_models()
|
| 249 |
+
|
| 250 |
+
return is_safe
|
| 251 |
+
|
| 252 |
+
def _run_guardrail_on_video(self, video: np.ndarray) -> np.ndarray | None:
|
| 253 |
+
"""Check if video meets safety requirements.
|
| 254 |
+
|
| 255 |
+
Validates generated video content against safety policies using guardrail models.
|
| 256 |
+
|
| 257 |
+
Args:
|
| 258 |
+
video: Video frames to validate
|
| 259 |
+
|
| 260 |
+
Returns:
|
| 261 |
+
np.ndarray: Processed video if safe, None if unsafe
|
| 262 |
+
"""
|
| 263 |
+
return guardrail_presets.run_video_guardrail(video, self.video_guardrail)
|
| 264 |
+
|
| 265 |
+
def _run_guardrail_on_video_with_offload(self, video: np.ndarray) -> np.ndarray | None:
|
| 266 |
+
"""Check if generated video meets safety requirements.
|
| 267 |
+
|
| 268 |
+
Args:
|
| 269 |
+
video: Video frames to validate
|
| 270 |
+
|
| 271 |
+
Returns:
|
| 272 |
+
np.ndarray: Processed video frames if safe, None otherwise
|
| 273 |
+
|
| 274 |
+
Note:
|
| 275 |
+
Guardrail models are offloaded after checks if enabled.
|
| 276 |
+
"""
|
| 277 |
+
if self.offload_guardrail_models:
|
| 278 |
+
self._load_video_guardrail()
|
| 279 |
+
|
| 280 |
+
video = self._run_guardrail_on_video(video)
|
| 281 |
+
|
| 282 |
+
if self.offload_guardrail_models:
|
| 283 |
+
self._offload_guardrail_models()
|
| 284 |
+
return video
|
| 285 |
+
|
| 286 |
+
def _run_text_embedding_on_prompt(
|
| 287 |
+
self, prompts: list[str], **kwargs: Any
|
| 288 |
+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
| 289 |
+
"""Convert text prompts to embeddings.
|
| 290 |
+
|
| 291 |
+
Processes text prompts into embedding tensors that condition the generation model.
|
| 292 |
+
|
| 293 |
+
Args:
|
| 294 |
+
prompts: List of text prompts to encode
|
| 295 |
+
**kwargs: Additional arguments for text encoding
|
| 296 |
+
|
| 297 |
+
Returns:
|
| 298 |
+
tuple containing:
|
| 299 |
+
- List of text embedding tensors for each prompt
|
| 300 |
+
- List of attention masks for each embedding
|
| 301 |
+
"""
|
| 302 |
+
|
| 303 |
+
embeddings = []
|
| 304 |
+
masks = []
|
| 305 |
+
for prompt in prompts:
|
| 306 |
+
embedding, mask = self.text_encoder.encode_prompts(
|
| 307 |
+
[prompt],
|
| 308 |
+
**kwargs,
|
| 309 |
+
)
|
| 310 |
+
embeddings.append(embedding)
|
| 311 |
+
masks.append(mask)
|
| 312 |
+
|
| 313 |
+
return embeddings, masks
|
| 314 |
+
|
| 315 |
+
def _run_text_embedding_on_prompt_with_offload(
|
| 316 |
+
self, prompts: list[str], **kwargs: Any
|
| 317 |
+
) -> tuple[list[torch.Tensor], list[torch.Tensor]]:
|
| 318 |
+
"""Convert text prompt into embeddings using T5 encoder.
|
| 319 |
+
|
| 320 |
+
Args:
|
| 321 |
+
prompt: Processed and validated text prompt
|
| 322 |
+
|
| 323 |
+
Returns:
|
| 324 |
+
Text embedding tensor to condition diffusion model
|
| 325 |
+
|
| 326 |
+
Note:
|
| 327 |
+
T5 model is offloaded after encoding if enabled.
|
| 328 |
+
"""
|
| 329 |
+
if self.offload_text_encoder_model:
|
| 330 |
+
self._load_text_encoder_model()
|
| 331 |
+
|
| 332 |
+
embeddings, masks = self._run_text_embedding_on_prompt(prompts, **kwargs)
|
| 333 |
+
|
| 334 |
+
if self.offload_text_encoder_model:
|
| 335 |
+
self._offload_text_encoder_model()
|
| 336 |
+
return embeddings, masks
|
| 337 |
+
|
| 338 |
+
def _run_tokenizer_decoding(self, samples: torch.Tensor) -> np.ndarray:
|
| 339 |
+
"""Decode model outputs into final world representation.
|
| 340 |
+
|
| 341 |
+
This abstract method must be implemented by subclasses to convert raw model
|
| 342 |
+
outputs into their specific world representation format.
|
| 343 |
+
|
| 344 |
+
Args:
|
| 345 |
+
samples: Raw output tensor from the generation model
|
| 346 |
+
|
| 347 |
+
Returns:
|
| 348 |
+
np.ndarray: Decoded world representation
|
| 349 |
+
"""
|
| 350 |
+
pass
|
| 351 |
+
|
| 352 |
+
def generate(self, *args: Any, **kwargs: Any):
|
| 353 |
+
"""Generate world representation.
|
| 354 |
+
|
| 355 |
+
This abstract method must be implemented by subclasses to convert raw model
|
| 356 |
+
outputs into their specific world representation format.
|
| 357 |
+
|
| 358 |
+
Args:
|
| 359 |
+
*args: Variable positional arguments for model inference
|
| 360 |
+
**kwargs: Variable keyword arguments for model inference
|
| 361 |
+
"""
|
| 362 |
+
pass
|
conditioner.py
ADDED
|
@@ -0,0 +1,323 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import copy
|
| 17 |
+
from abc import ABC, abstractmethod
|
| 18 |
+
from collections import defaultdict
|
| 19 |
+
from dataclasses import dataclass, fields
|
| 20 |
+
from enum import Enum
|
| 21 |
+
from typing import Any, Dict, List, Optional, Tuple, Union
|
| 22 |
+
|
| 23 |
+
import torch
|
| 24 |
+
import torch.nn as nn
|
| 25 |
+
|
| 26 |
+
from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
|
| 27 |
+
from Cosmos.utils import log
|
| 28 |
+
from Cosmos.lazy_config import instantiate
|
| 29 |
+
|
| 30 |
+
|
| 31 |
+
class BaseConditionEntry(nn.Module):
    """Base class for a single conditioning input (an "embedder").

    Each entry carries a dropout rate (for classifier-free guidance style
    condition dropping), the batch key(s) it reads its input from, and a flag
    for whether its forward returns a dict.
    """

    def __init__(self):
        super().__init__()

        self._dropout_rate = None
        self._input_key = None
        self._return_dict = False

    # --- dropout_rate -----------------------------------------------------
    @property
    def dropout_rate(self) -> Union[float, torch.Tensor]:
        """Per-sample probability of dropping this condition."""
        return self._dropout_rate

    @dropout_rate.setter
    def dropout_rate(self, value: Union[float, torch.Tensor]):
        self._dropout_rate = value

    @dropout_rate.deleter
    def dropout_rate(self):
        del self._dropout_rate

    # --- input_key --------------------------------------------------------
    @property
    def input_key(self) -> str:
        """Name of the batch entry this embedder consumes."""
        return self._input_key

    @input_key.setter
    def input_key(self, value: str):
        self._input_key = value

    @input_key.deleter
    def input_key(self):
        del self._input_key

    # --- is_return_dict ---------------------------------------------------
    @property
    def is_return_dict(self) -> bool:
        """Whether forward() returns a dict of named outputs."""
        return self._return_dict

    @is_return_dict.setter
    def is_return_dict(self, value: bool):
        self._return_dict = value

    @is_return_dict.deleter
    def is_return_dict(self):
        del self._return_dict

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Randomly zero whole batch elements of *in_tensor*.

        Each element in the batch is kept with probability ``1 - dropout_rate``
        (Bernoulli, sampled independently per element).

        Args:
            in_tensor: Input of shape (B, ...); dropout is applied along dim 0.
            dropout_rate: Override rate; falls back to ``self.dropout_rate``.
            key: Unused here; lets subclasses special-case certain inputs.
        """
        del key  # unused in the base implementation
        if dropout_rate is None:
            dropout_rate = self.dropout_rate
        keep_prob = 1.0 - dropout_rate
        keep_mask = torch.bernoulli(keep_prob * torch.ones(in_tensor.shape[0])).type_as(in_tensor)
        return batch_mul(keep_mask, in_tensor)

    def summary(self) -> str:
        """Human-readable description; subclasses may override. Returns None here."""
        pass
| 87 |
+
|
| 88 |
+
|
| 89 |
+
class DataType(Enum):
    """Modality tag for a conditioned sample: a single image or a video clip."""

    IMAGE = "image"
    VIDEO = "video"
| 92 |
+
|
| 93 |
+
|
| 94 |
+
class TextAttr(BaseConditionEntry):
    """Text conditioning entry: packages token embeddings and attention mask
    under the keys the network's cross-attention expects."""

    def __init__(self):
        super().__init__()

    def forward(self, token: torch.Tensor, mask: torch.Tensor):
        # Key names match GeneralConditioner.KEY2DIM so outputs concatenate
        # on the right dimension.
        return {"crossattn_emb": token, "crossattn_mask": mask}

    def random_dropout_input(
        self, in_tensor: torch.Tensor, dropout_rate: Optional[float] = None, key: Optional[str] = None
    ) -> torch.Tensor:
        """Apply condition dropout, but never drop attention masks."""
        is_mask_input = key is not None and "mask" in key
        if is_mask_input:
            # Masks pass through untouched; only embeddings are dropped.
            return in_tensor
        return super().random_dropout_input(in_tensor, dropout_rate, key)
| 107 |
+
|
| 108 |
+
|
| 109 |
+
@dataclass
class BaseVideoCondition:
    """Conditioning bundle consumed by the base video diffusion model.

    Required fields are the text cross-attention inputs; the rest are optional
    auxiliary conditions (all tensors, batch-first where applicable).
    """

    crossattn_emb: torch.Tensor
    crossattn_mask: torch.Tensor
    data_type: DataType = DataType.VIDEO
    padding_mask: Optional[torch.Tensor] = None
    fps: Optional[torch.Tensor] = None
    num_frames: Optional[torch.Tensor] = None
    image_size: Optional[torch.Tensor] = None
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return a shallow field-name -> value mapping (tensors are not copied)."""
        return {f.name: getattr(self, f.name) for f in fields(self)}
| 122 |
+
|
| 123 |
+
|
| 124 |
+
@dataclass
class VideoExtendCondition(BaseVideoCondition):
    """Extra condition fields for extending an existing video (video2world)."""

    video_cond_bool: Optional[torch.Tensor] = None  # whether or not it is conditioned on video
    gt_latent: Optional[torch.Tensor] = None
    condition_video_indicator: Optional[torch.Tensor] = None  # 1 for condition region

    # condition_video_input_mask will be concatenated to the input of the network,
    # along the channel dim
    condition_video_input_mask: Optional[torch.Tensor] = None
    # condition_video_augment_sigma: (B, T) tensor of sigma values for the conditional input
    # augmentation; only valid when apply_corruption_to_condition_region is
    # "noise_with_sigma" or "noise_with_sigma_fixed"
    condition_video_augment_sigma: Optional[torch.Tensor] = None
| 135 |
+
|
| 136 |
+
|
| 137 |
+
class GeneralConditioner(nn.Module, ABC):
    """
    An abstract module designed to handle various embedding models with conditional and
    unconditional configurations. This abstract base class initializes and manages a collection
    of embedders that can dynamically adjust their dropout rates based on conditioning.

    Attributes:
        KEY2DIM (dict): A mapping from output keys to dimensions used for concatenation.
        embedders (nn.ModuleDict): A dictionary containing all embedded models initialized and
            configured based on the provided configurations.

    Parameters:
        emb_models (Union[List, Any]): A dictionary where keys are embedder names and values
            are configurations for initializing the embedders.

    """

    KEY2DIM = {"crossattn_emb": 1, "crossattn_mask": 1}

    def __init__(self, **emb_models: Union[List, Any]):
        super().__init__()
        self.embedders = nn.ModuleDict()
        for n, (emb_name, embconfig) in enumerate(emb_models.items()):
            # Each config carries an instantiable `.obj` plus optional
            # dropout_rate and input_key(s) attributes.
            embedder = instantiate(embconfig.obj)
            assert isinstance(
                embedder, BaseConditionEntry
            ), f"embedder model {embedder.__class__.__name__} has to inherit from AbstractEmbModel"
            embedder.dropout_rate = getattr(embconfig, "dropout_rate", 0.0)

            # An embedder reads either a single batch key (`input_key`) or
            # several (`input_keys`); exactly one must be configured.
            if hasattr(embconfig, "input_key"):
                embedder.input_key = embconfig.input_key
            elif hasattr(embconfig, "input_keys"):
                embedder.input_keys = embconfig.input_keys
            else:
                raise KeyError(f"need either 'input_key' or 'input_keys' for embedder {embedder.__class__.__name__}")

            log.debug(f"Initialized embedder #{n}-{emb_name}: \n {embedder.summary()}")
            self.embedders[emb_name] = embedder

    @abstractmethod
    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Any:
        """Should be implemented in subclasses to handle conditon datatype"""
        raise NotImplementedError

    def _forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> Dict:
        """
        Processes the input batch through all configured embedders, applying conditional dropout rates if specified.
        Output tensors for each key are concatenated along the dimensions specified in KEY2DIM.

        Parameters:
            batch (Dict): The input data batch to process.
            override_dropout_rate (Optional[Dict[str, float]]): Optional dictionary to override default dropout rates
                per embedder key.

        Returns:
            Dict: A dictionary of output tensors concatenated by specified dimensions.

        Note:
            In case the network code is sensitive to the order of concatenation, you can either control the order via \
            config file or make sure the embedders return a unique key for each output.
        """
        output = defaultdict(list)
        if override_dropout_rate is None:
            override_dropout_rate = {}

        # make sure emb_name in override_dropout_rate is valid
        for emb_name in override_dropout_rate.keys():
            assert emb_name in self.embedders, f"invalid name found {emb_name}"

        for emb_name, embedder in self.embedders.items():
            # Embedders run without gradients: conditions are treated as
            # fixed inputs to the diffusion network.
            with torch.no_grad():
                # Single-key embedders take one tensor; multi-key embedders
                # take positional tensors in input_keys order. An embedder
                # with input_keys has input_key == None, so the elif fires.
                if hasattr(embedder, "input_key") and (embedder.input_key is not None):
                    emb_out = embedder(
                        embedder.random_dropout_input(
                            batch[embedder.input_key], override_dropout_rate.get(emb_name, None)
                        )
                    )
                elif hasattr(embedder, "input_keys"):
                    emb_out = embedder(
                        *[
                            embedder.random_dropout_input(batch[k], override_dropout_rate.get(emb_name, None), k)
                            for k in embedder.input_keys
                        ]
                    )
            for k, v in emb_out.items():
                output[k].append(v)
        # Concatenate the outputs
        return {k: torch.cat(v, dim=self.KEY2DIM.get(k, -1)) for k, v in output.items()}

    def get_condition_uncondition(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Processes the provided data batch to generate conditioned and unconditioned outputs.

        This method manipulates dropout rates to simulate two scenarios:
        1. All conditions applied (conditioned)
        2. Conditions removed/reduced to minimum (unconditioned)

        This method sets dropout rates to zero for the conditioned scenario to fully apply
        embedders' effects. For unconditioned, it sets rates to 1 (or 0 if initial rate is
        insignificant) to minimize embedder influences.

        Parameters:
            data_batch (Dict): Input data batch containing all necessary information for
                embedding processing.

        Returns:
            Tuple[Any, Any]: A tuple containing:
                - Outputs with all embedders fully applied (conditioned)
                - Outputs with embedders minimized/not applied (unconditioned)
        """
        cond_dropout_rates, dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            # Embedders that were never dropped during training (rate ~0)
            # are kept in the unconditioned pass too.
            dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch, override_dropout_rate=dropout_rates)
        return condition, un_condition

    def get_condition_with_negative_prompt(
        self,
        data_batch: Dict,
    ) -> Tuple[Any, Any]:
        """
        Similar functionality as get_condition_uncondition
        But use negative prompts for unconditon
        """
        cond_dropout_rates, uncond_dropout_rates = {}, {}
        for emb_name, embedder in self.embedders.items():
            cond_dropout_rates[emb_name] = 0.0
            # Text is NOT dropped for the "uncondition": the negative prompt
            # embeddings substitute for it instead.
            if isinstance(embedder, TextAttr):
                uncond_dropout_rates[emb_name] = 0.0
            else:
                uncond_dropout_rates[emb_name] = 1.0 if embedder.dropout_rate > 1e-4 else 0.0

        # Swap in the negative-prompt embeddings (when present) for the
        # unconditioned pass; the positive batch is left untouched.
        data_batch_neg_prompt = copy.deepcopy(data_batch)
        if "neg_t5_text_embeddings" in data_batch_neg_prompt:
            if isinstance(data_batch_neg_prompt["neg_t5_text_embeddings"], torch.Tensor):
                data_batch_neg_prompt["t5_text_embeddings"] = data_batch_neg_prompt["neg_t5_text_embeddings"]
                data_batch_neg_prompt["t5_text_mask"] = data_batch_neg_prompt["neg_t5_text_mask"]

        condition: Any = self(data_batch, override_dropout_rate=cond_dropout_rates)
        un_condition: Any = self(data_batch_neg_prompt, override_dropout_rate=uncond_dropout_rates)

        return condition, un_condition
| 293 |
+
|
| 294 |
+
|
| 295 |
+
@dataclass
class CosmosCondition:
    """Container for the conditioning tensors fed to a Cosmos diffusion model."""

    # Cross-attention context embeddings (e.g. T5 text embeddings).
    crossattn_emb: torch.Tensor
    # Mask accompanying ``crossattn_emb``.
    crossattn_mask: torch.Tensor
    # Optional padding mask; semantics defined by the consumer.
    padding_mask: Optional[torch.Tensor] = None
    # Optional scalar conditioning feature.
    scalar_feature: Optional[torch.Tensor] = None

    def to_dict(self) -> Dict[str, Optional[torch.Tensor]]:
        """Return the condition as a field-name -> tensor dictionary."""
        result: Dict[str, Optional[torch.Tensor]] = {}
        for field_info in fields(self):
            result[field_info.name] = getattr(self, field_info.name)
        return result
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
class VideoConditioner(GeneralConditioner):
    """Conditioner that packages embedder outputs into a ``BaseVideoCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> BaseVideoCondition:
        """Run all embedders on ``batch`` and wrap the combined result."""
        embedder_outputs = super()._forward(batch, override_dropout_rate)
        return BaseVideoCondition(**embedder_outputs)
|
| 314 |
+
|
| 315 |
+
|
| 316 |
+
class VideoExtendConditioner(GeneralConditioner):
    """Conditioner that packages embedder outputs into a ``VideoExtendCondition``."""

    def forward(
        self,
        batch: Dict,
        override_dropout_rate: Optional[Dict[str, float]] = None,
    ) -> VideoExtendCondition:
        """Run all embedders on ``batch`` and wrap the combined result."""
        embedder_outputs = super()._forward(batch, override_dropout_rate)
        return VideoExtendCondition(**embedder_outputs)
|
config.json
ADDED
|
@@ -0,0 +1,10 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"architectures": [
|
| 3 |
+
"DiffusionVideo2World"
|
| 4 |
+
],
|
| 5 |
+
"auto_map": {
|
| 6 |
+
"AutoConfig": "video2world_hf.DiffusionVideo2WorldConfig",
|
| 7 |
+
"AutoModel": "video2world_hf.DiffusionVideo2World"
|
| 8 |
+
},
|
| 9 |
+
"model_type": "AutoModel"
|
| 10 |
+
}
|
convert_pixtral_ckpt.py
ADDED
|
@@ -0,0 +1,209 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""Convert pretrained Pixtral vision model weights to checkpoint and verify the checkpoint loading.
|
| 17 |
+
|
| 18 |
+
Usage:
|
| 19 |
+
|
| 20 |
+
PYTHONPATH=$(pwd) python cosmos1/scripts/convert_pixtral_ckpt.py
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import json
|
| 25 |
+
import os
|
| 26 |
+
import shutil
|
| 27 |
+
from glob import glob
|
| 28 |
+
|
| 29 |
+
import torch
|
| 30 |
+
from huggingface_hub import snapshot_download
|
| 31 |
+
from safetensors.torch import load_file
|
| 32 |
+
|
| 33 |
+
|
| 34 |
+
def convert_pixtral_checkpoint(checkpoint_dir: str, checkpoint_name: str, vit_type: str):
    """
    Convert pretrained Pixtral vision model weights to a Cosmos checkpoint.

    Args:
        checkpoint_dir (str): Path to the checkpoint directory
        checkpoint_name (str): Name of the checkpoint
        vit_type (str): Type of ViT used in the Pixtral model

    This function performs the following steps:
    0. Download the checkpoint from Hugging Face
    1. Loads the original Pixtral checkpoint
    2. Splits the checkpoint into vision encoder, projector, and LLM weights
    3. Reorganizes the weights to match the expected format
    4. Extracts and verifies the vision encoder configuration
    5. Saves the converted checkpoint and configuration
    """
    save_dir = os.path.join(checkpoint_dir, checkpoint_name)
    os.makedirs(save_dir, exist_ok=True)
    # Skip the (expensive) download + conversion when the output already exists.
    save_path = os.path.join(save_dir, "model.pt")
    if os.path.exists(save_path) and os.path.getsize(save_path) > 0:
        print(f"Checkpoint {save_path} already exists and is not empty")
        return

    pixtral_ckpt_dir = os.path.join(checkpoint_dir, "Pixtral-12B-2409")
    os.makedirs(pixtral_ckpt_dir, exist_ok=True)
    repo_id = "mistralai/Pixtral-12B-2409"
    print(f"Downloading {repo_id} to {pixtral_ckpt_dir}...")
    snapshot_download(
        repo_id=repo_id,
        allow_patterns=["params.json", "consolidated.safetensors"],
        local_dir=pixtral_ckpt_dir,
        local_dir_use_symlinks=False,
    )
    orig_dtype = torch.get_default_dtype()
    dtype = torch.bfloat16
    torch.set_default_dtype(dtype)

    # Load checkpoint file
    ckpt_files = glob(os.path.join(pixtral_ckpt_dir, "*.safetensors"))
    assert len(ckpt_files) == 1, "ckpt_dir should contain only one file"
    ckpt_path = ckpt_files[0]
    ckpt = load_file(ckpt_path)

    # Split checkpoint into weights of vision encoder, projector, and LLM
    vit_key_prefix = "vision_encoder."
    vit_ckpt = {}
    for key, value in ckpt.items():
        if key.startswith(vit_key_prefix):
            # BUGFIX: the original used key.lstrip(vit_key_prefix).  str.lstrip()
            # strips *any leading characters in the given set*, not a prefix, and
            # silently corrupts keys such as "vision_encoder.norm.weight"
            # (-> "m.weight").  Slice off the prefix by length instead.
            vit_ckpt[key[len(vit_key_prefix):]] = value

    projector_key_prefix = "vision_language_adapter."
    projector_ckpt = {}
    substring_replacement_map = {
        "w_in.": "projector.0.",
        "w_out.": "projector.2.",
    }
    for key, value in ckpt.items():
        if key.startswith(projector_key_prefix):
            key = key[len(projector_key_prefix):]  # strip the prefix, not a char set
            for old, new in substring_replacement_map.items():
                key = key.replace(old, new)
            projector_ckpt[key] = value

    llm_ckpt = {}
    for key, value in ckpt.items():
        if key.startswith(vit_key_prefix) or key.startswith(projector_key_prefix):
            continue
        llm_ckpt[key] = value

    # Re-assemble all weights under the namespaces expected by the VLM loader.
    vlm_ckpt = {}
    for key, value in llm_ckpt.items():
        vlm_ckpt["model." + key] = value
    for key, value in projector_ckpt.items():
        vlm_ckpt["mm_projector." + key] = value
    for key, value in vit_ckpt.items():
        vlm_ckpt["vision_encoder." + key] = value

    # Load config
    config_path = os.path.join(pixtral_ckpt_dir, "params.json")
    with open(config_path, "r") as f:
        pixtral_config = json.load(f)

    # Extract the vision encoder configuration
    vision_encoder_config = {
        "dim": pixtral_config["vision_encoder"]["hidden_size"],
        "num_channels": pixtral_config["vision_encoder"]["num_channels"],
        "image_size": pixtral_config["vision_encoder"]["image_size"],
        "patch_size": pixtral_config["vision_encoder"]["patch_size"],
        "rope_theta": pixtral_config["vision_encoder"]["rope_theta"],
        "ffn_hidden_size": pixtral_config["vision_encoder"]["intermediate_size"],
        "n_layers": pixtral_config["vision_encoder"]["num_hidden_layers"],
        "n_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
        "n_kv_heads": pixtral_config["vision_encoder"]["num_attention_heads"],
        "norm_type": "rmsnorm",
        "norm_eps": pixtral_config["norm_eps"],
        "image_token_id": pixtral_config["vision_encoder"]["image_token_id"],
    }
    # Expected configuration for the 400M ViT of Pixtral 12B VLM
    vit_config = dict(
        dim=1024,
        num_channels=3,
        image_size=1024,
        patch_size=16,
        rope_theta=10000,
        ffn_hidden_size=4096,
        n_layers=24,
        n_heads=16,
        n_kv_heads=16,
        norm_type="rmsnorm",
        norm_eps=1e-5,
        image_token_id=10,
    )
    # Verify the downloaded config matches the expected one.
    for key, value in vit_config.items():
        assert vision_encoder_config[key] == value, f"Mismatch in {key}: {vision_encoder_config[key]} != {value}"

    llm_config_keys = [
        "dim",
        "n_layers",
        "head_dim",
        "hidden_dim",
        "n_heads",
        "n_kv_heads",
        "rope_theta",
        "norm_eps",
        "vocab_size",
    ]
    assert set(list(pixtral_config.keys())) == set(llm_config_keys + ["vision_encoder"]), "Config keys mismatch"
    replace_map = {
        "hidden_dim": "ffn_hidden_size",
    }
    llm_config = {}
    for k, v in pixtral_config.items():
        if k in llm_config_keys:
            llm_config[replace_map.get(k, k)] = v
        elif k == "vision_encoder":
            llm_config["vision_encoder"] = vit_type
        else:
            raise ValueError(f"Unknown key: {k}")

    ckpt_to_save = {"model": vlm_ckpt, "mm_projector": projector_ckpt, "vision_encoder": vit_ckpt}
    torch.save(ckpt_to_save, save_path)
    print(f"Model saved to {save_path}")

    # Save config
    config_path = os.path.join(save_dir, "config.json")
    with open(config_path, "w") as f:
        json.dump(llm_config, f)

    torch.set_default_dtype(orig_dtype)  # Reset the default dtype

    # Remove the original Pixtral checkpoint to reclaim disk space.
    shutil.rmtree(pixtral_ckpt_dir, ignore_errors=True)
    print(f"Removed {pixtral_ckpt_dir}")
|
| 192 |
+
|
| 193 |
+
|
| 194 |
+
if __name__ == "__main__":
    # CLI entry point: parse options and run the conversion.
    arg_parser = argparse.ArgumentParser(
        description="Convert pretrained Pixtral vision model weights to checkpoint and verify accuracy"
    )
    arg_parser.add_argument("--checkpoint_dir", type=str, default="checkpoints", help="Path to the checkpoint directory")
    arg_parser.add_argument(
        "--checkpoint_name",
        type=str,
        default="Pixtral-12B",
        help="Name of the checkpoint",
    )
    arg_parser.add_argument("--vit_type", default="pixtral-12b-vit", help="Type of ViT used in the Pixtral model")
    cli_args = arg_parser.parse_args()
    convert_pixtral_checkpoint(
        checkpoint_dir=cli_args.checkpoint_dir,
        checkpoint_name=cli_args.checkpoint_name,
        vit_type=cli_args.vit_type,
    )
|
download_diffusion.py
ADDED
|
@@ -0,0 +1,123 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
from pathlib import Path
|
| 18 |
+
|
| 19 |
+
from huggingface_hub import snapshot_download
|
| 20 |
+
|
| 21 |
+
from Cosmos.convert_pixtral_ckpt import convert_pixtral_checkpoint
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
def parse_args():
    """Build and parse the CLI arguments for the download script."""
    arg_parser = argparse.ArgumentParser(description="Download NVIDIA Cosmos-1.0 Diffusion models from Hugging Face")
    size_choices = ["7B", "14B"]
    type_choices = ["Text2World", "Video2World"]
    arg_parser.add_argument(
        "--model_sizes",
        nargs="*",
        default=list(size_choices),  # Download all by default
        choices=size_choices,
        help="Which model sizes to download. Possible values: 7B, 14B",
    )
    arg_parser.add_argument(
        "--model_types",
        nargs="*",
        default=list(type_choices),  # Download all by default
        choices=type_choices,
        help="Which model types to download. Possible values: Text2World, Video2World",
    )
    arg_parser.add_argument(
        "--cosmos_version",
        type=str,
        default="1.0",
        choices=["1.0"],
        help="Which version of Cosmos to download. Only 1.0 is available at the moment.",
    )
    arg_parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Directory to save the downloaded checkpoints."
    )
    return arg_parser.parse_args()
|
| 58 |
+
|
| 59 |
+
|
| 60 |
+
def main(args):
    """Download the requested Cosmos-1.0 diffusion checkpoints plus support models.

    Args:
        args: Parsed CLI namespace with ``model_sizes``, ``model_types`` and
            ``checkpoint_dir`` attributes (see ``parse_args``).
    """
    ORG_NAME = "nvidia"

    # Hugging Face repository name for each supported model size.
    size_to_model = {
        "7B": "Cosmos-1.0-Diffusion-7B",
        "14B": "Cosmos-1.0-Diffusion-14B",
    }

    # Support models downloaded regardless of the requested sizes.
    extra_models = [
        "Cosmos-1.0-Guardrail",
        "Cosmos-1.0-Tokenizer-CV8x8x8",
    ]
    if "Text2World" in args.model_types:
        extra_models.append("Cosmos-1.0-Prompt-Upsampler-12B-Text2World")

    # Create local checkpoints folder
    checkpoints_dir = Path(args.checkpoint_dir)
    checkpoints_dir.mkdir(parents=True, exist_ok=True)

    download_kwargs = dict(allow_patterns=["README.md", "model.pt", "config.json", "*.jit"])

    # Download the requested diffusion model variants (size x type).
    for size in args.model_sizes:
        for model_type in args.model_types:
            model_name = f"{size_to_model[size]}-{model_type}"
            repo_id = f"{ORG_NAME}/{model_name}"
            local_dir = checkpoints_dir / model_name
            local_dir.mkdir(parents=True, exist_ok=True)

            print(f"Downloading {repo_id} to {local_dir}...")
            snapshot_download(
                repo_id=repo_id, local_dir=str(local_dir), local_dir_use_symlinks=False, **download_kwargs
            )

    # Download the always-included support models; fetch all files for these
    # (e.g. Guardrail ships more than the patterns above).
    for model_name in extra_models:
        repo_id = f"{ORG_NAME}/{model_name}"
        local_dir = checkpoints_dir / model_name
        local_dir.mkdir(parents=True, exist_ok=True)

        print(f"Downloading {repo_id} to {local_dir}...")
        snapshot_download(
            repo_id=repo_id,
            local_dir=str(local_dir),
            local_dir_use_symlinks=False,
        )

    if "Video2World" in args.model_types:
        # Prompt Upsampler for Cosmos-1.0-Diffusion-Video2World models
        convert_pixtral_checkpoint(
            checkpoint_dir=args.checkpoint_dir,
            checkpoint_name="Pixtral-12B",
            vit_type="pixtral-12b-vit",
        )
|
| 119 |
+
|
| 120 |
+
|
| 121 |
+
if __name__ == "__main__":
    # Parse CLI options and run the downloader.
    main(parse_args())
|
guardrail_presets.py
ADDED
|
@@ -0,0 +1,77 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import os
|
| 17 |
+
|
| 18 |
+
import numpy as np
|
| 19 |
+
|
| 20 |
+
from cosmos1.models.guardrail.aegis.aegis import Aegis
|
| 21 |
+
from cosmos1.models.guardrail.blocklist.blocklist import Blocklist
|
| 22 |
+
from cosmos1.models.guardrail.common.core import GuardrailRunner
|
| 23 |
+
from cosmos1.models.guardrail.face_blur_filter.face_blur_filter import RetinaFaceFilter
|
| 24 |
+
from cosmos1.models.guardrail.video_content_safety_filter.video_content_safety_filter import VideoContentSafetyFilter
|
| 25 |
+
from Cosmos.utils import log
|
| 26 |
+
|
| 27 |
+
|
| 28 |
+
def create_text_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the guardrail runner used to screen text prompts."""
    safety_models = [
        Blocklist(os.path.join(checkpoint_dir, "blocklist")),
        Aegis(os.path.join(checkpoint_dir, "aegis")),
    ]
    return GuardrailRunner(safety_models=safety_models)
|
| 33 |
+
|
| 34 |
+
|
| 35 |
+
def create_video_guardrail_runner(checkpoint_dir: str) -> GuardrailRunner:
    """Build the guardrail runner used to screen and post-process generated videos."""
    content_filter = VideoContentSafetyFilter(os.path.join(checkpoint_dir, "video_content_safety_filter"))
    face_blur = RetinaFaceFilter(os.path.join(checkpoint_dir, "face_blur_filter/Resnet50_Final.pth"))
    return GuardrailRunner(
        safety_models=[content_filter],
        postprocessors=[face_blur],
    )
|
| 43 |
+
|
| 44 |
+
|
| 45 |
+
def run_text_guardrail(prompt: str, guardrail_runner: GuardrailRunner) -> bool:
    """Run the text guardrail on the prompt, checking for content safety.

    Args:
        prompt: The text prompt.
        guardrail_runner: The text guardrail runner.

    Returns:
        bool: Whether the prompt is safe.
    """
    safe, reason = guardrail_runner.run_safety_check(prompt)
    if not safe:
        # Surface the block loudly so operators can see why generation stopped.
        log.critical(f"GUARDRAIL BLOCKED: {reason}")
    return safe
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
def run_video_guardrail(frames: np.ndarray, guardrail_runner: GuardrailRunner) -> np.ndarray | None:
    """Run the video guardrail on the frames: content-safety check, then face blur.

    Args:
        frames: The frames of the generated video.
        guardrail_runner: The video guardrail runner.

    Returns:
        The processed (post-processed) frames if safe, otherwise None.
    """
    safe, reason = guardrail_runner.run_safety_check(frames)
    if not safe:
        # Surface the block loudly so operators can see why generation stopped.
        log.critical(f"GUARDRAIL BLOCKED: {reason}")
        return None
    return guardrail_runner.postprocess(frames)
|
inference_utils.py
ADDED
|
@@ -0,0 +1,726 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import importlib
|
| 18 |
+
from contextlib import contextmanager
|
| 19 |
+
from typing import List, NamedTuple, Optional, Tuple
|
| 20 |
+
|
| 21 |
+
from Cosmos.utils import misc
|
| 22 |
+
import einops
|
| 23 |
+
import imageio
|
| 24 |
+
import numpy as np
|
| 25 |
+
import torch
|
| 26 |
+
import torchvision.transforms.functional as transforms_F
|
| 27 |
+
|
| 28 |
+
from Cosmos.model_t2w import DiffusionT2WModel
|
| 29 |
+
from Cosmos.model_v2w import DiffusionV2WModel
|
| 30 |
+
from Cosmos.utils import log
|
| 31 |
+
from Cosmos.utils.config_helper import get_config_module, override
|
| 32 |
+
from Cosmos.utils.io import load_from_fileobj
|
| 33 |
+
|
| 34 |
+
# (major, minor) of the installed torch, e.g. "2.1.0+cu118" -> (2, 1).
TORCH_VERSION: Tuple[int, ...] = tuple(int(x) for x in torch.__version__.split(".")[:2])
if TORCH_VERSION >= (1, 11):
    # torch >= 1.11 exposes quantization under the torch.ao namespace.
    from torch.ao import quantization
    from torch.ao.quantization import FakeQuantizeBase, ObserverBase
elif (
    TORCH_VERSION >= (1, 8)
    and hasattr(torch.quantization, "FakeQuantizeBase")
    and hasattr(torch.quantization, "ObserverBase")
):
    # Older torch (>= 1.8) ships the same symbols under torch.quantization.
    from torch import quantization
    from torch.quantization import FakeQuantizeBase, ObserverBase

# NOTE(review): presumably the default sigma for noise augmentation of
# conditioning frames -- confirm against the callers that use it.
DEFAULT_AUGMENT_SIGMA = 0.001
|
| 47 |
+
|
| 48 |
+
|
| 49 |
+
def add_common_arguments(parser):
    """Add common command line arguments for text2world and video2world generation.

    Args:
        parser (ArgumentParser): Argument parser to add arguments to

    The arguments include:
    - checkpoint_dir: Base directory containing model weights
    - tokenizer_dir: Directory containing tokenizer weights
    - video_save_name: Output video filename for single video generation
    - video_save_folder: Output directory for batch video generation
    - prompt: Text prompt for single video generation
    - batch_input_path: Path to JSONL file with input prompts for batch video generation
    - negative_prompt: Text prompt describing undesired attributes
    - num_steps: Number of diffusion sampling steps
    - guidance: Classifier-free guidance scale
    - num_video_frames: Number of frames to generate
    - height/width: Output video dimensions
    - fps: Output video frame rate
    - seed: Random seed for reproducibility
    - Various model offloading flags
    """
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--tokenizer_dir",
        type=str,
        default="Cosmos-1.0-Tokenizer-CV8x8x8",
        help="Tokenizer weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--video_save_name",
        type=str,
        default="output",
        help="Output filename for generating a single video",
    )
    parser.add_argument(
        "--video_save_folder",
        type=str,
        default="outputs/",
        help="Output folder for generating a batch of videos",
    )
    parser.add_argument(
        "--prompt",
        type=str,
        help="Text prompt for generating a single video",
    )
    parser.add_argument(
        "--batch_input_path",
        type=str,
        help="Path to a JSONL file of input prompts for generating a batch of videos",
    )
    parser.add_argument(
        "--negative_prompt",
        type=str,
        default=(
            "The video captures a series of frames showing ugly scenes, static with no motion, motion blur, "
            "over-saturation, shaky footage, low resolution, grainy texture, pixelated images, poorly lit areas, "
            "underexposed and overexposed scenes, poor color balance, washed out colors, choppy sequences, "
            "jerky movements, low frame rate, artifacting, color banding, unnatural transitions, outdated special "
            "effects, fake elements, unconvincing visuals, poorly edited content, jump cuts, visual noise, and "
            "flickering. Overall, the video is of poor quality."
        ),
        help="Negative prompt for the video",
    )
    parser.add_argument("--num_steps", type=int, default=35, help="Number of diffusion sampling steps")
    parser.add_argument("--guidance", type=float, default=7, help="Guidance scale value")
    parser.add_argument("--num_video_frames", type=int, default=121, help="Number of video frames to sample")
    parser.add_argument("--height", type=int, default=704, help="Height of video to sample")
    parser.add_argument("--width", type=int, default=1280, help="Width of video to sample")
    parser.add_argument("--fps", type=int, default=24, help="FPS of the sampled video")
    parser.add_argument("--seed", type=int, default=1, help="Random seed")
    # All boolean switches have the same shape, so register them in one loop.
    for flag, flag_help in [
        ("--disable_prompt_upsampler", "Disable prompt upsampling"),
        ("--offload_diffusion_transformer", "Offload DiT after inference"),
        ("--offload_tokenizer", "Offload tokenizer after inference"),
        ("--offload_text_encoder_model", "Offload text encoder model after inference"),
        ("--offload_prompt_upsampler", "Offload prompt upsampler after inference"),
        ("--offload_guardrail_models", "Offload guardrail models after inference"),
    ]:
        parser.add_argument(flag, action="store_true", help=flag_help)
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
def validate_args(args: argparse.Namespace, inference_type: str) -> None:
    """Validate command line arguments for text2world and video2world generation."""
    valid_types = ("text2world", "video2world")
    assert inference_type in valid_types, "Invalid inference_type, must be 'text2world' or 'video2world'"

    # A prompt (or a batch file of prompts) is mandatory whenever the prompt
    # upsampler cannot supply one: always for text2world, and for video2world
    # only when the upsampler has been explicitly disabled.
    needs_prompt = inference_type == "text2world" or (
        inference_type == "video2world" and args.disable_prompt_upsampler
    )
    if needs_prompt:
        assert args.prompt or args.batch_input_path, "--prompt or --batch_input_path must be provided."

    # Single-sample video2world runs must point at a conditioning image/video.
    if inference_type == "video2world" and not args.batch_input_path:
        assert (
            args.input_image_or_video_path
        ), "--input_image_or_video_path must be provided for single video generation."
|
| 166 |
+
|
| 167 |
+
|
| 168 |
+
class _IncompatibleKeys(
|
| 169 |
+
NamedTuple(
|
| 170 |
+
"IncompatibleKeys",
|
| 171 |
+
[
|
| 172 |
+
("missing_keys", List[str]),
|
| 173 |
+
("unexpected_keys", List[str]),
|
| 174 |
+
("incorrect_shapes", List[Tuple[str, Tuple[int], Tuple[int]]]),
|
| 175 |
+
],
|
| 176 |
+
)
|
| 177 |
+
):
|
| 178 |
+
pass
|
| 179 |
+
|
| 180 |
+
|
| 181 |
+
def non_strict_load_model(model: torch.nn.Module, checkpoint_state_dict: dict) -> _IncompatibleKeys:
    """Load a model checkpoint with non-strict matching, handling shape mismatches.

    Args:
        model (torch.nn.Module): Model to load weights into
        checkpoint_state_dict (dict): State dict from checkpoint. NOTE: entries
            with mismatched shapes are popped from this dict in place.

    Returns:
        _IncompatibleKeys: Named tuple containing:
            - missing_keys: Keys present in model but missing from checkpoint
            - unexpected_keys: Keys present in checkpoint but not in model
            - incorrect_shapes: Keys with mismatched tensor shapes

    The function handles special cases like:
        - Uninitialized parameters
        - Quantization observers
        - TransformerEngine FP8 states
    """
    # workaround https://github.com/pytorch/pytorch/issues/24139
    model_state_dict = model.state_dict()
    incorrect_shapes = []
    # Iterate over a snapshot of the keys because entries may be popped below.
    for k in list(checkpoint_state_dict.keys()):
        if k in model_state_dict:
            if "_extra_state" in k:  # Key introduced by TransformerEngine for FP8
                log.debug(f"Skipping key {k} introduced by TransformerEngine for FP8 in the checkpoint.")
                continue
            model_param = model_state_dict[k]
            # Allow mismatch for uninitialized parameters
            if TORCH_VERSION >= (1, 8) and isinstance(model_param, torch.nn.parameter.UninitializedParameter):
                continue
            if not isinstance(model_param, torch.Tensor):
                raise ValueError(
                    f"Find non-tensor parameter {k} in the model. type: {type(model_param)} {type(checkpoint_state_dict[k])}, please check if this key is safe to skip or not."
                )

            shape_model = tuple(model_param.shape)
            shape_checkpoint = tuple(checkpoint_state_dict[k].shape)
            if shape_model != shape_checkpoint:
                # Quantization observer / fake-quant modules may legitimately
                # carry buffers whose shapes differ from the checkpoint.
                has_observer_base_classes = (
                    TORCH_VERSION >= (1, 8)
                    and hasattr(quantization, "ObserverBase")
                    and hasattr(quantization, "FakeQuantizeBase")
                )
                if has_observer_base_classes:
                    # Handle the special case of quantization per channel observers,
                    # where buffer shape mismatches are expected.
                    def _get_module_for_key(model: torch.nn.Module, key: str) -> torch.nn.Module:
                        # foo.bar.param_or_buffer_name -> [foo, bar]
                        key_parts = key.split(".")[:-1]
                        cur_module = model
                        for key_part in key_parts:
                            cur_module = getattr(cur_module, key_part)
                        return cur_module

                    cls_to_skip = (
                        ObserverBase,
                        FakeQuantizeBase,
                    )
                    target_module = _get_module_for_key(model, k)
                    if isinstance(target_module, cls_to_skip):
                        # Keep these entries in the state_dict despite the
                        # expected shape mismatch: their _load_from_state_dict
                        # has special logic to resolve the mismatch itself.
                        continue

                incorrect_shapes.append((k, shape_checkpoint, shape_model))
                checkpoint_state_dict.pop(k)
    incompatible = model.load_state_dict(checkpoint_state_dict, strict=False)
    # Remove keys with "_extra_state" suffix, which are non-parameter items introduced by TransformerEngine for FP8 handling
    missing_keys = [k for k in incompatible.missing_keys if "_extra_state" not in k]
    unexpected_keys = [k for k in incompatible.unexpected_keys if "_extra_state" not in k]
    return _IncompatibleKeys(
        missing_keys=missing_keys,
        unexpected_keys=unexpected_keys,
        incorrect_shapes=incorrect_shapes,
    )
|
| 257 |
+
|
| 258 |
+
|
| 259 |
+
@contextmanager
def skip_init_linear():
    """Context manager that temporarily disables nn.Linear weight initialization.

    Monkey-patches ``torch.nn.Linear.reset_parameters`` and
    ``torch.nn.init.xavier_uniform_`` to no-ops so large models can be
    constructed quickly before real weights are loaded from a checkpoint.

    Fix over the original: the patches are now restored in a ``finally``
    block, so an exception raised inside the ``with`` body no longer leaves
    the global torch initializers permanently disabled.
    """
    # skip init of nn.Linear
    orig_reset_parameters = torch.nn.Linear.reset_parameters
    torch.nn.Linear.reset_parameters = lambda x: x
    xavier_uniform_ = torch.nn.init.xavier_uniform_
    torch.nn.init.xavier_uniform_ = lambda x: x
    try:
        yield
    finally:
        # Always restore the originals, even if the wrapped block raises.
        torch.nn.Linear.reset_parameters = orig_reset_parameters
        torch.nn.init.xavier_uniform_ = xavier_uniform_
|
| 269 |
+
|
| 270 |
+
|
| 271 |
+
def load_model_by_config(
    config_job_name,
    config_file="projects/cosmos_video/config/config.py",
    model_class=DiffusionT2WModel,
):
    """Build a model instance from a registered experiment configuration.

    Args:
        config_job_name: Experiment name selected via the config override.
        config_file: Path to the Python config module.
        model_class: Model class to instantiate with the resolved config.

    Returns:
        An instance of ``model_class`` constructed with linear init skipped.
    """
    module_name = get_config_module(config_file)
    config = importlib.import_module(module_name).make_config()
    config = override(config, ["--", f"experiment={config_job_name}"])

    # Validate, then freeze so the config cannot be mutated afterwards.
    config.validate()
    config.freeze()  # type: ignore

    # Skip default weight init; real weights come from a checkpoint later.
    with skip_init_linear():
        model = model_class(config.model)
    return model
|
| 290 |
+
|
| 291 |
+
|
| 292 |
+
def load_network_model(model: DiffusionT2WModel, ckpt_path: str):
    """Load network weights from ``ckpt_path`` into ``model`` and move it to GPU."""
    with skip_init_linear():
        model.set_up_model()
    state_dict = torch.load(ckpt_path, map_location="cpu", weights_only=True)
    # Non-strict load: mismatched/missing keys are reported, not fatal.
    log.debug(non_strict_load_model(model.model, state_dict))
    model.cuda()
|
| 298 |
+
|
| 299 |
+
|
| 300 |
+
def load_tokenizer_model(model: DiffusionT2WModel, tokenizer_dir: str):
    """Set up the tokenizer (VAE) from ``tokenizer_dir`` and move the model to GPU."""
    with skip_init_linear():
        model.set_up_tokenizer(tokenizer_dir)
    model.cuda()
|
| 304 |
+
|
| 305 |
+
|
| 306 |
+
def prepare_data_batch(
    height: int,
    width: int,
    num_frames: int,
    fps: int,
    prompt_embedding: torch.Tensor,
    negative_prompt_embedding: Optional[torch.Tensor] = None,
):
    """Prepare input batch tensors for video generation.

    Args:
        height (int): Height of video frames
        width (int): Width of video frames
        num_frames (int): Number of frames to generate
        fps (int): Frames per second
        prompt_embedding (torch.Tensor): Encoded text prompt embeddings
        negative_prompt_embedding (torch.Tensor, optional): Encoded negative prompt embeddings

    Returns:
        dict: Batch dictionary with the placeholder video tensor, text masks,
        frame-size/fps metadata, padding mask, and prompt embeddings (plus the
        negative-prompt embeddings and mask when supplied). All tensors are
        allocated on the current CUDA device.
    """
    batch_size = 1
    data_batch = {
        "video": torch.zeros((batch_size, 3, num_frames, height, width), dtype=torch.uint8).cuda(),
        "t5_text_mask": torch.ones(batch_size, 512, dtype=torch.bfloat16).cuda(),
        "image_size": torch.tensor([[height, width, height, width]] * batch_size, dtype=torch.bfloat16).cuda(),
        "fps": torch.tensor([fps] * batch_size, dtype=torch.bfloat16).cuda(),
        "num_frames": torch.tensor([num_frames] * batch_size, dtype=torch.bfloat16).cuda(),
        "padding_mask": torch.zeros((batch_size, 1, height, width), dtype=torch.bfloat16).cuda(),
    }

    # Attach the (positive) prompt embedding, cast to bfloat16 on GPU.
    data_batch["t5_text_embeddings"] = prompt_embedding.to(dtype=torch.bfloat16).cuda()

    # Optionally attach the negative prompt embedding and its attention mask.
    if negative_prompt_embedding is not None:
        data_batch["neg_t5_text_embeddings"] = negative_prompt_embedding.to(dtype=torch.bfloat16).cuda()
        data_batch["neg_t5_text_mask"] = torch.ones(batch_size, 512, dtype=torch.bfloat16).cuda()

    return data_batch
|
| 357 |
+
|
| 358 |
+
|
| 359 |
+
def get_video_batch(model, prompt_embedding, negative_prompt_embedding, height, width, fps, num_video_frames):
    """Prepare complete input batch for video generation including latent dimensions.

    Args:
        model: Diffusion model instance
        prompt_embedding (torch.Tensor): Text prompt embeddings
        negative_prompt_embedding (torch.Tensor): Negative prompt embeddings
        height (int): Output video height
        width (int): Output video width
        fps (int): Output video frame rate
        num_video_frames (int): Number of frames to generate

    Returns:
        tuple:
            - data_batch (dict): Complete model input batch
            - state_shape (list): Latent state shape [C,T,H,W] after the
              tokenizer's temporal and spatial compression
    """
    data_batch = prepare_data_batch(
        height=height,
        width=width,
        num_frames=num_video_frames,
        fps=fps,
        prompt_embedding=prompt_embedding,
        negative_prompt_embedding=negative_prompt_embedding,
    )
    # Latent state: temporal length from the tokenizer's frame mapping,
    # spatial size reduced by the tokenizer's compression factor.
    tokenizer = model.tokenizer
    state_shape = [
        tokenizer.channel,
        tokenizer.get_latent_num_frames(num_video_frames),
        height // tokenizer.spatial_compression_factor,
        width // tokenizer.spatial_compression_factor,
    ]
    return data_batch, state_shape
|
| 391 |
+
|
| 392 |
+
|
| 393 |
+
def generate_world_from_text(
    model: DiffusionT2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
):
    """Generate video from text prompt using diffusion model.

    Args:
        model (DiffusionT2WModel): Text-to-video diffusion model
        state_shape (list[int]): Latent state dimensions [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Model input batch with embeddings
        guidance (float): Classifier-free guidance scale
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for reproducibility

    Returns:
        The sample produced by ``model.generate_samples_from_batch``.
    """
    # Seed the sampler with architecture-invariant noise scaled to the
    # SDE's maximum sigma (i.e. start from pure noise).
    noise = misc.arch_invariant_rand(
        (1,) + tuple(state_shape),
        torch.float32,
        model.tensor_kwargs["device"],
        seed,
    )
    x_sigma_max = noise * model.sde.sigma_max

    # Run guided diffusion sampling.
    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        x_sigma_max=x_sigma_max,
    )

    return sample
|
| 443 |
+
|
| 444 |
+
|
| 445 |
+
def generate_world_from_video(
    model: DiffusionV2WModel,
    state_shape: list[int],
    is_negative_prompt: bool,
    data_batch: dict,
    guidance: float,
    num_steps: int,
    seed: int,
    condition_latent: torch.Tensor,
    num_input_frames: int,
) -> torch.Tensor:
    """Generate video using a conditioning video/image input.

    Fix over the original: the return annotation claimed
    ``Tuple[np.array, list, list]`` and the docstring claimed a decoded
    ``np.array`` in [0,255], but the function returns the single object
    produced by ``model.generate_samples_from_batch``.

    Args:
        model (DiffusionV2WModel): The diffusion model instance
        state_shape (list[int]): Shape of the latent state [C,T,H,W]
        is_negative_prompt (bool): Whether negative prompt is provided
        data_batch (dict): Batch containing model inputs including text embeddings
        guidance (float): Classifier-free guidance scale for sampling
        num_steps (int): Number of diffusion sampling steps
        seed (int): Random seed for generation
        condition_latent (torch.Tensor): Latent tensor from conditioning video/image file
        num_input_frames (int): Number of input frames

    Returns:
        torch.Tensor: The generated sample from the diffusion model.
    """
    assert not model.config.conditioner.video_cond_bool.sample_tokens_start_from_p_or_i, "not supported"
    augment_sigma = DEFAULT_AUGMENT_SIGMA

    if condition_latent.shape[2] < state_shape[1]:
        # Zero-pad the condition latent along time so it matches the full
        # latent temporal length expected by the sampler.
        b, c, t, h, w = condition_latent.shape
        condition_latent = torch.cat(
            [
                condition_latent,
                condition_latent.new_zeros(b, c, state_shape[1] - t, h, w),
            ],
            dim=2,
        ).contiguous()
    num_of_latent_condition = compute_num_latent_frames(model, num_input_frames)

    # Start from pure noise at the SDE's maximum sigma.
    x_sigma_max = (
        misc.arch_invariant_rand(
            (1,) + tuple(state_shape),
            torch.float32,
            model.tensor_kwargs["device"],
            seed,
        )
        * model.sde.sigma_max
    )

    sample = model.generate_samples_from_batch(
        data_batch,
        guidance=guidance,
        state_shape=state_shape,
        num_steps=num_steps,
        is_negative_prompt=is_negative_prompt,
        seed=seed,
        condition_latent=condition_latent,
        num_condition_t=num_of_latent_condition,
        condition_video_augment_sigma_in_inference=augment_sigma,
        x_sigma_max=x_sigma_max,
    )
    return sample
|
| 510 |
+
|
| 511 |
+
|
| 512 |
+
def read_video_or_image_into_frames_BCTHW(
    input_path: str,
    input_path_format: str = "mp4",
    H: Optional[int] = None,
    W: Optional[int] = None,
    normalize: bool = True,
    max_frames: int = -1,
    also_return_fps: bool = False,
) -> torch.Tensor:
    """Read video or image file and convert to tensor format.

    Args:
        input_path (str): Path to input video/image file
        input_path_format (str): Format of input file (default: "mp4")
        H (int, optional): Height to resize frames to
        W (int, optional): Width to resize frames to
        normalize (bool): Whether to normalize pixel values to [-1,1] (default: True)
        max_frames (int): Maximum number of frames to read (-1 for all frames)
        also_return_fps (bool): Whether to return fps along with frames

    Returns:
        torch.Tensor | tuple: Video tensor in shape [B,C,T,H,W], optionally with fps if requested
    """
    log.debug(f"Reading video from {input_path}")

    loaded_data = load_from_fileobj(input_path, format=input_path_format)
    frames, meta_data = loaded_data
    if input_path.endswith(".png") or input_path.endswith(".jpg") or input_path.endswith(".jpeg"):
        frames = np.array(frames[0])  # HWC, [0,255]
        if frames.shape[-1] > 3:  # RGBA, set the transparent to white
            # Separate the RGB and Alpha channels
            rgb_channels = frames[..., :3]
            alpha_channel = frames[..., 3] / 255.0  # Normalize alpha channel to [0, 1]

            # Create a white background
            white_bg = np.ones_like(rgb_channels) * 255  # White background in RGB

            # Blend the RGB channels with the white background based on the alpha channel
            frames = (rgb_channels * alpha_channel[..., None] + white_bg * (1 - alpha_channel[..., None])).astype(
                np.uint8
            )
        frames = [frames]
        fps = 0  # a still image carries no frame rate
    else:
        fps = int(meta_data.get("fps"))
    if max_frames != -1:
        frames = frames[:max_frames]
    input_tensor = np.stack(frames, axis=0)
    input_tensor = einops.rearrange(input_tensor, "t h w c -> t c h w")
    if normalize:
        # Map [0, 255] to approximately [-1, 1] (255 maps to 0.9921875).
        input_tensor = input_tensor / 128.0 - 1.0
    input_tensor = torch.from_numpy(input_tensor).bfloat16()  # TCHW
    log.debug(f"Raw data shape: {input_tensor.shape}")
    if H is not None and W is not None:
        # Bicubic, antialiased resize to the requested spatial size.
        input_tensor = transforms_F.resize(
            input_tensor,
            size=(H, W),  # type: ignore
            interpolation=transforms_F.InterpolationMode.BICUBIC,
            antialias=True,
        )
    input_tensor = einops.rearrange(input_tensor, "(b t) c h w -> b c t h w", b=1)
    if normalize:
        # NOTE(review): normalized tensors are also moved to GPU here, while
        # unnormalized ones stay on CPU — confirm callers rely on this.
        input_tensor = input_tensor.to("cuda")
    log.debug(f"Load shape {input_tensor.shape} value {input_tensor.min()}, {input_tensor.max()}")
    if also_return_fps:
        return input_tensor, fps
    return input_tensor
|
| 579 |
+
|
| 580 |
+
|
| 581 |
+
def compute_num_latent_frames(model: DiffusionV2WModel, num_input_frames: int, downsample_factor=8) -> int:
    """This function computes the number of latent frames given the number of input frames.

    Args:
        model (DiffusionV2WModel): video generation model
        num_input_frames (int): number of input frames
        downsample_factor (int): downsample factor for temporal reduce
    Returns:
        int: number of latent frames
    """
    # Whole chunks: each full pixel chunk maps to one latent chunk.
    num_latent_frames = (
        num_input_frames
        // model.tokenizer.video_vae.pixel_chunk_duration
        * model.tokenizer.video_vae.latent_chunk_duration
    )
    # NOTE(review): the remainder below is computed modulo
    # latent_chunk_duration, while the assert and the increment use
    # pixel_chunk_duration — confirm this mix is intentional.
    if num_input_frames % model.tokenizer.video_vae.latent_chunk_duration == 1:
        # A single leftover frame adds exactly one latent frame.
        num_latent_frames += 1
    elif num_input_frames % model.tokenizer.video_vae.latent_chunk_duration > 1:
        assert (
            num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1
        ) % downsample_factor == 0, f"num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1 must be divisible by {downsample_factor}"
        # Leftover frames beyond the first contribute one latent frame per
        # downsample_factor pixel frames, plus one for the first frame.
        num_latent_frames += (
            1 + (num_input_frames % model.tokenizer.video_vae.pixel_chunk_duration - 1) // downsample_factor
        )

    return num_latent_frames
|
| 606 |
+
|
| 607 |
+
|
| 608 |
+
def create_condition_latent_from_input_frames(
    model: DiffusionV2WModel,
    input_frames: torch.Tensor,
    num_frames_condition: int = 25,
):
    """Create condition latent for video generation from input frames.

    The last ``num_frames_condition`` frames of the input become the
    conditioning clip; the encoder input is zero-padded at the end up to the
    tokenizer's pixel chunk duration before being encoded.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_frames (torch.Tensor): Input video tensor [B,C,T,H,W], range [-1,1]
        num_frames_condition (int): Number of frames to use for conditioning

    Returns:
        tuple: (condition_latent, encode_input_frames) where:
            - condition_latent (torch.Tensor): Encoded latent condition [B,C,T,H,W]
            - encode_input_frames (torch.Tensor): Padded input frames used for encoding
    """
    B, C, T, H, W = input_frames.shape
    num_frames_encode = (
        model.tokenizer.pixel_chunk_duration
    )  # (model.state_shape[1] - 1) / model.vae.pixel_chunk_duration + 1
    log.debug(
        f"num_frames_encode not set, set it based on pixel chunk duration and model state shape: {num_frames_encode}"
    )

    log.debug(
        f"Create condition latent from input frames {input_frames.shape}, value {input_frames.min()}, {input_frames.max()}, dtype {input_frames.dtype}"
    )

    # Sanity checks: enough source frames, and the encoder window must be
    # able to contain the conditioning clip.
    assert (
        input_frames.shape[2] >= num_frames_condition
    ), f"input_frames not enough for condition, require at least {num_frames_condition}, get {input_frames.shape[2]}, {input_frames.shape}"
    assert (
        num_frames_encode >= num_frames_condition
    ), f"num_frames_encode should be larger than num_frames_condition, get {num_frames_encode}, {num_frames_condition}"

    # Conditioning frames go at the beginning of the encoder input; the
    # remainder is zero padding.
    tail_frames = input_frames[:, :, -num_frames_condition:]
    zero_pad = tail_frames.new_zeros(B, C, num_frames_encode - num_frames_condition, H, W)
    encode_input_frames = torch.cat([tail_frames, zero_pad], dim=2)

    log.debug(
        f"create latent with input shape {encode_input_frames.shape} including padding {num_frames_encode - num_frames_condition} at the end"
    )
    latent = model.encode(encode_input_frames)
    return latent, encode_input_frames
|
| 656 |
+
|
| 657 |
+
|
| 658 |
+
def get_condition_latent(
    model: DiffusionV2WModel,
    input_image_or_video_path: str,
    num_input_frames: int = 1,
    state_shape: Optional[list[int]] = None,
):
    """Get condition latent from input image/video file.

    Args:
        model (DiffusionV2WModel): Video generation model
        input_image_or_video_path (str): Path to conditioning image/video
        num_input_frames (int): Number of input frames for video2world prediction
        state_shape (list[int], optional): Latent state shape [C,T,H,W];
            defaults to ``model.state_shape`` when not given.

    Returns:
        torch.Tensor: Encoded latent condition [B,C,T,H,W], cast to bfloat16.
    """
    if state_shape is None:
        state_shape = model.state_shape
    assert num_input_frames > 0, "num_input_frames must be greater than 0"

    # Recover the pixel-space resolution from the latent state shape.
    H, W = (
        state_shape[-2] * model.tokenizer.spatial_compression_factor,
        state_shape[-1] * model.tokenizer.spatial_compression_factor,
    )

    # Infer the container format from the file extension.
    input_path_format = input_image_or_video_path.split(".")[-1]
    input_frames = read_video_or_image_into_frames_BCTHW(
        input_image_or_video_path,
        input_path_format=input_path_format,
        H=H,
        W=W,
    )

    condition_latent, _ = create_condition_latent_from_input_frames(model, input_frames, num_input_frames)
    condition_latent = condition_latent.to(torch.bfloat16)

    return condition_latent
|
| 697 |
+
|
| 698 |
+
|
| 699 |
+
def check_input_frames(input_path: str, required_frames: int) -> bool:
    """Check if input video/image has sufficient frames.

    Fix over the original: the docstring claimed an ``np.ndarray``/``None``
    return, but the function returns a bool; also flattened the redundant
    else-after-return.

    Args:
        input_path: Path to input video or image
        required_frames: Number of required frames

    Returns:
        True if the input can supply ``required_frames`` frames, False otherwise.
    """
    if input_path.endswith((".jpg", ".jpeg", ".png")):
        # A still image provides exactly one frame.
        if required_frames > 1:
            log.error(f"Input ({input_path}) is an image but {required_frames} frames are required")
            return False
        return True  # Let the pipeline handle image loading
    # For video input
    try:
        vid = imageio.get_reader(input_path, "ffmpeg")
        frame_count = vid.count_frames()
    except Exception as e:
        log.error(f"Error reading video file {input_path}: {e}")
        return False
    if frame_count < required_frames:
        log.error(f"Input video has {frame_count} frames but {required_frames} frames are required")
        return False
    return True
|
model_t2w.py
ADDED
|
@@ -0,0 +1,282 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import Callable, Dict, Optional, Tuple
|
| 17 |
+
|
| 18 |
+
from Cosmos.utils import misc
|
| 19 |
+
import torch
|
| 20 |
+
from torch import Tensor
|
| 21 |
+
|
| 22 |
+
from Cosmos.conditioner import CosmosCondition
|
| 23 |
+
from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
|
| 24 |
+
from cosmos1.models.diffusion.diffusion.modules.denoiser_scaling import EDMScaling
|
| 25 |
+
from cosmos1.models.diffusion.diffusion.modules.res_sampler import COMMON_SOLVER_OPTIONS, Sampler
|
| 26 |
+
from Cosmos.types import DenoisePrediction
|
| 27 |
+
from Cosmos.module.blocks import FourierFeatures
|
| 28 |
+
from Cosmos.module.pretrained_vae import BaseVAE
|
| 29 |
+
from Cosmos.utils import log
|
| 30 |
+
from Cosmos.lazy_config import instantiate as lazy_instantiate
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
class EDMSDE:
    """Minimal container for the EDM SDE noise range.

    Holds the maximum and minimum noise levels (sigmas) used during sampling.
    """

    def __init__(self, sigma_max: float, sigma_min: float) -> None:
        # Highest noise level: sampling starts from here.
        self.sigma_max = sigma_max
        # Lowest noise level.
        self.sigma_min = sigma_min
|
| 41 |
+
|
| 42 |
+
|
| 43 |
+
class DiffusionT2WModel(torch.nn.Module):
|
| 44 |
+
"""Text-to-world diffusion model that generates video frames from text descriptions.
|
| 45 |
+
|
| 46 |
+
This model implements a diffusion-based approach for generating videos conditioned on text input.
|
| 47 |
+
It handles the full pipeline including encoding/decoding through a VAE, diffusion sampling,
|
| 48 |
+
and classifier-free guidance.
|
| 49 |
+
"""
|
| 50 |
+
|
| 51 |
+
def __init__(self, config):
|
| 52 |
+
"""Initialize the diffusion model.
|
| 53 |
+
|
| 54 |
+
Args:
|
| 55 |
+
config: Configuration object containing model parameters and architecture settings
|
| 56 |
+
"""
|
| 57 |
+
super().__init__()
|
| 58 |
+
# Initialize trained_data_record with defaultdict, key: image, video, iteration
|
| 59 |
+
self.config = config
|
| 60 |
+
|
| 61 |
+
self.precision = {
|
| 62 |
+
"float32": torch.float32,
|
| 63 |
+
"float16": torch.float16,
|
| 64 |
+
"bfloat16": torch.bfloat16,
|
| 65 |
+
}[config.precision]
|
| 66 |
+
self.tensor_kwargs = {"device": "cuda", "dtype": self.precision}
|
| 67 |
+
log.debug(f"DiffusionModel: precision {self.precision}")
|
| 68 |
+
# Timer passed to network to detect slow ranks.
|
| 69 |
+
# 1. set data keys and data information
|
| 70 |
+
self.sigma_data = config.sigma_data
|
| 71 |
+
self.state_shape = list(config.latent_shape)
|
| 72 |
+
self.setup_data_key()
|
| 73 |
+
|
| 74 |
+
# 2. setup up diffusion processing and scaling~(pre-condition), sampler
|
| 75 |
+
self.sde = EDMSDE(sigma_max=80, sigma_min=0.0002)
|
| 76 |
+
self.sampler = Sampler()
|
| 77 |
+
self.scaling = EDMScaling(self.sigma_data)
|
| 78 |
+
self.tokenizer = None
|
| 79 |
+
self.model = None
|
| 80 |
+
|
| 81 |
+
@property
|
| 82 |
+
def net(self):
|
| 83 |
+
return self.model.net
|
| 84 |
+
|
| 85 |
+
@property
|
| 86 |
+
def conditioner(self):
|
| 87 |
+
return self.model.conditioner
|
| 88 |
+
|
| 89 |
+
@property
|
| 90 |
+
def logvar(self):
|
| 91 |
+
return self.model.logvar
|
| 92 |
+
|
| 93 |
+
def set_up_tokenizer(self, tokenizer_dir: str):
|
| 94 |
+
self.tokenizer: BaseVAE = lazy_instantiate(self.config.tokenizer)
|
| 95 |
+
self.tokenizer.load_weights(tokenizer_dir)
|
| 96 |
+
if hasattr(self.tokenizer, "reset_dtype"):
|
| 97 |
+
self.tokenizer.reset_dtype()
|
| 98 |
+
|
| 99 |
+
@misc.timer("DiffusionModel: set_up_model")
def set_up_model(self, memory_format: torch.memory_format = torch.preserve_format):
    """Initialize the core model components including network, conditioner and logvar.

    Args:
        memory_format: Memory layout to apply when moving the model to device.
    """
    # Build the ModuleDict and move it to the configured device/dtype in one step.
    self.model = self.build_model().to(memory_format=memory_format, **self.tensor_kwargs)
|
| 104 |
+
|
| 105 |
+
def build_model(self) -> torch.nn.ModuleDict:
    """Construct the model's neural network components.

    Returns:
        ModuleDict containing the network, conditioner and logvar components.
    """
    cfg = self.config
    # Instantiate components in the same order as registration: net, conditioner, logvar.
    components = {
        "net": lazy_instantiate(cfg.net),
        "conditioner": lazy_instantiate(cfg.conditioner),
        # Small head mapping a Fourier embedding of the noise level to a scalar log-variance.
        "logvar": torch.nn.Sequential(
            FourierFeatures(num_channels=128, normalize=True),
            torch.nn.Linear(128, 1, bias=False),
        ),
    }
    return torch.nn.ModuleDict(components)
|
| 125 |
+
|
| 126 |
+
@torch.no_grad()
def encode(self, state: torch.Tensor) -> torch.Tensor:
    """Encode input state into latent representation using the VAE tokenizer.

    Args:
        state: Input tensor to encode.

    Returns:
        Encoded latent representation scaled by sigma_data.
    """
    raw_latent = self.tokenizer.encode(state)
    # Scale by sigma_data so latents match the EDM preconditioning convention.
    return raw_latent * self.sigma_data
|
| 137 |
+
|
| 138 |
+
@torch.no_grad()
def decode(self, latent: torch.Tensor) -> torch.Tensor:
    """Decode latent representation back to pixel space using the VAE tokenizer.

    Args:
        latent: Latent tensor to decode.

    Returns:
        Decoded tensor in pixel space.
    """
    # Undo the sigma_data scaling applied in encode() before decoding.
    unscaled = latent / self.sigma_data
    return self.tokenizer.decode(unscaled)
|
| 149 |
+
|
| 150 |
+
def setup_data_key(self) -> None:
    """Configure the input data key used to index batches.

    By default this is the video key for the video diffusion model.
    """
    self.input_data_key = self.config.input_data_key
|
| 153 |
+
|
| 154 |
+
def get_x0_fn_from_batch(
    self,
    data_batch: Dict,
    guidance: float = 1.5,
    is_negative_prompt: bool = False,
) -> Callable:
    """Build a classifier-free-guided x0 prediction function from a data batch.

    Runs the batch through the conditioner to obtain conditioned and
    unconditioned states, then returns a closure that denoises a noisy input
    at a given noise level and combines the two predictions with CFG.

    Args:
        data_batch: A batch of data used for conditioning; format should align with the conditioner.
        guidance: Scalar modulating the influence of the conditioned vs unconditioned state.
        is_negative_prompt: Use negative-prompt t5 embeddings for the unconditioned branch if true.

    Returns:
        A function ``x0_fn(noise_x, sigma)`` returning the x0 prediction.
    """
    if is_negative_prompt:
        condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
    else:
        condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

    def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
        denoised_cond = self.denoise(noise_x, sigma, condition).x0
        denoised_uncond = self.denoise(noise_x, sigma, uncondition).x0
        # Classifier-free guidance: extrapolate away from the unconditioned prediction.
        guided_x0 = denoised_cond + guidance * (denoised_cond - denoised_uncond)
        if "guided_image" not in data_batch:
            return guided_x0
        # replacement trick that enables inpainting with base model
        assert "guided_mask" in data_batch, "guided_mask should be in data_batch if guided_image is present"
        mask = data_batch["guided_mask"]
        image = data_batch["guided_image"]
        return mask * image + (1 - mask) * guided_x0

    return x0_fn
|
| 194 |
+
|
| 195 |
+
def denoise(self, xt: torch.Tensor, sigma: torch.Tensor, condition: CosmosCondition) -> DenoisePrediction:
    """
    Performs denoising on the input noise data, noise level, and condition.

    Args:
        xt (torch.Tensor): The input noise data.
        sigma (torch.Tensor): The noise level.
        condition (CosmosCondition): conditional information, generated from self.conditioner

    Returns:
        DenoisePrediction: The denoised prediction; it includes the clean data prediction (x0), \
        noise prediction (eps_pred) and optional confidence (logvar).
    """

    # Move inputs onto the compute device/dtype configured for the network.
    xt = xt.to(**self.tensor_kwargs)
    sigma = sigma.to(**self.tensor_kwargs)
    # get precondition for the network (EDM preconditioning coefficients)
    c_skip, c_out, c_in, c_noise = self.scaling(sigma=sigma)

    # forward pass through the network
    net_output = self.net(
        x=batch_mul(c_in, xt),  # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
        timesteps=c_noise,  # Eq. 7 of https://arxiv.org/pdf/2206.00364.pdf
        **condition.to_dict(),
    )

    # Log-variance head is evaluated on the same preconditioned noise level.
    logvar = self.model.logvar(c_noise)
    # Recombine the skip connection and the network output into the clean-data estimate.
    x0_pred = batch_mul(c_skip, xt) + batch_mul(c_out, net_output)

    # get noise prediction based on sde: eps = (x_t - x0) / sigma
    eps_pred = batch_mul(xt - x0_pred, 1.0 / sigma)

    return DenoisePrediction(x0_pred, eps_pred, logvar)
|
| 228 |
+
|
| 229 |
+
def generate_samples_from_batch(
    self,
    data_batch: Dict,
    guidance: float = 1.5,
    seed: int = 1,
    state_shape: Tuple | None = None,
    n_sample: int | None = None,
    is_negative_prompt: bool = False,
    num_steps: int = 35,
    solver_option: COMMON_SOLVER_OPTIONS = "2ab",
    x_sigma_max: Optional[torch.Tensor] = None,
    sigma_max: float | None = None,
) -> Tensor:
    """Generate samples from a data batch using diffusion sampling.

    Handles both conditional and unconditional generation with
    classifier-free guidance.

    Args:
        data_batch (Dict): Raw data batch from the training data loader.
        guidance (float, optional): Classifier-free guidance weight. Defaults to 1.5.
        seed (int, optional): Random seed for reproducibility. Defaults to 1.
        state_shape (Tuple | None, optional): Shape of the state tensor. Uses self.state_shape if None.
        n_sample (int | None, optional): Number of samples to generate. Defaults to None.
        is_negative_prompt (bool, optional): Whether to use a negative prompt for the unconditional branch.
        num_steps (int, optional): Number of diffusion sampling steps. Defaults to 35.
        solver_option (COMMON_SOLVER_OPTIONS, optional): ODE solver option. Defaults to "2ab" (multistep).
        x_sigma_max (Optional[torch.Tensor], optional): Initial noisy tensor; randomly drawn when None.
        sigma_max (float | None, optional): Maximum noise level. Uses self.sde.sigma_max if None.

    Returns:
        Tensor: Generated samples after diffusion sampling.
    """
    denoiser_fn = self.get_x0_fn_from_batch(data_batch, guidance, is_negative_prompt=is_negative_prompt)

    if sigma_max is None:
        sigma_max = self.sde.sigma_max
    else:
        log.info("Using provided sigma_max for diffusion sampling.")

    if x_sigma_max is None:
        # Draw architecture-invariant Gaussian noise and scale it to sigma_max.
        initial_noise = misc.arch_invariant_rand(
            (n_sample, *state_shape),
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )
        x_sigma_max = initial_noise * sigma_max

    return self.sampler(
        denoiser_fn, x_sigma_max, num_steps=num_steps, sigma_max=sigma_max, solver_option=solver_option
    )
|
model_v2w.py
ADDED
|
@@ -0,0 +1,341 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Callable, Dict, Optional, Tuple, Union
|
| 18 |
+
|
| 19 |
+
from Cosmos.utils import misc
|
| 20 |
+
import torch
|
| 21 |
+
from torch import Tensor
|
| 22 |
+
|
| 23 |
+
from Cosmos.conditioner import VideoExtendCondition
|
| 24 |
+
from cosmos1.models.diffusion.config.base.conditioner import VideoCondBoolConfig
|
| 25 |
+
from cosmos1.models.diffusion.diffusion.functional.batch_ops import batch_mul
|
| 26 |
+
from Cosmos.model_t2w import DiffusionT2WModel
|
| 27 |
+
from Cosmos.utils import log
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
@dataclass
class VideoDenoisePrediction:
    """Outputs of a single video denoising step."""

    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of noise prediction; can be used as a confidence / uncertainty
    xt: Optional[torch.Tensor] = None  # input to the network, before multiplying with c_in
    x0_pred_replaced: Optional[torch.Tensor] = None  # x0 prediction with condition region replaced by gt_latent
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class DiffusionV2WModel(DiffusionT2WModel):
    """Video-to-world diffusion model: extends the text-to-world model with
    conditioning on given video frames (video extension / continuation)."""

    def __init__(self, config):
        """Initialize the model; configuration schema is the same as DiffusionT2WModel's."""
        super().__init__(config)

    def augment_conditional_latent_frames(
        self,
        condition: VideoExtendCondition,
        cfg_video_cond_bool: VideoCondBoolConfig,
        gt_latent: Tensor,
        condition_video_augment_sigma_in_inference: float = 0.001,
        sigma: Optional[Tensor] = None,
        seed: int = 1,
    ) -> Tuple[VideoExtendCondition, Tensor]:
        """Augments the conditional frames with noise during inference.

        Args:
            condition (VideoExtendCondition): condition object
                condition_video_indicator: binary tensor indicating the region is condition(value=1) or generation(value=0). Bx1xTx1x1 tensor.
                condition_video_input_mask: input mask for the network input, indicating the condition region. B,1,T,H,W tensor; will be concatenated with the input for the network.
            cfg_video_cond_bool (VideoCondBoolConfig): video condition bool config
                # NOTE(review): this parameter is currently unused in the body — confirm whether it is kept for interface compatibility.
            gt_latent (Tensor): ground truth latent tensor in shape B,C,T,H,W
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            sigma (Tensor): noise level for the generation region
            seed (int): random seed for reproducibility
        Returns:
            VideoExtendCondition: updated condition object
            augment_latent (Tensor): augmented latent tensor in shape B,C,T,H,W (pre-divided by c_in so the later c_in multiply cancels)
        """

        # Inference only, use fixed sigma for the condition region
        assert (
            condition_video_augment_sigma_in_inference is not None
        ), "condition_video_augment_sigma_in_inference should be provided"
        augment_sigma = condition_video_augment_sigma_in_inference

        if augment_sigma >= sigma.flatten()[0]:
            # This is an inference trick! If the sampling sigma is smaller than the augment sigma, we start denoising the condition region together.
            # This is achieved by setting all regions as `generation`, i.e. value=0
            log.debug("augment_sigma larger than sigma or other frame, remove condition")
            condition.condition_video_indicator = condition.condition_video_indicator * 0

        augment_sigma = torch.tensor([augment_sigma], **self.tensor_kwargs)

        # Now apply the augment_sigma to the gt_latent
        noise = misc.arch_invariant_rand(
            gt_latent.shape,
            torch.float32,
            self.tensor_kwargs["device"],
            seed,
        )

        augment_latent = gt_latent + noise * augment_sigma[:, None, None, None, None]

        _, _, c_in_augment, _ = self.scaling(sigma=augment_sigma)

        # Multiply the whole latent with c_in_augment
        augment_latent_cin = batch_mul(augment_latent, c_in_augment)

        # Since the whole latent will be multiplied with c_in later, we divide the value here to cancel the effect
        _, _, c_in, _ = self.scaling(sigma=sigma)
        augment_latent_cin = batch_mul(augment_latent_cin, 1 / c_in)

        return condition, augment_latent_cin

    def denoise(
        self,
        noise_x: Tensor,
        sigma: Tensor,
        condition: VideoExtendCondition,
        condition_video_augment_sigma_in_inference: float = 0.001,
        seed: int = 1,
    ) -> VideoDenoisePrediction:
        """Denoises input tensor using conditional video generation.

        Args:
            noise_x (Tensor): Noisy input tensor.
            sigma (Tensor): Noise level.
            condition (VideoExtendCondition): Condition for denoising.
            condition_video_augment_sigma_in_inference (float): sigma for condition video augmentation in inference
            seed (int): Random seed for reproducibility
        Returns:
            VideoDenoisePrediction containing:
            - x0: Denoised prediction
            - eps: Noise prediction
            - logvar: Log variance of noise prediction
            - xt: Input before c_in multiplication
            - x0_pred_replaced: x0 prediction with condition regions replaced by ground truth
        """

        assert (
            condition.gt_latent is not None
        ), f"find None gt_latent in condition, likely didn't call self.add_condition_video_indicator_and_video_input_mask when preparing the condition or this is a image batch but condition.data_type is wrong, get {noise_x.shape}"
        gt_latent = condition.gt_latent
        cfg_video_cond_bool: VideoCondBoolConfig = self.config.conditioner.video_cond_bool

        condition_latent = gt_latent

        # Augment the latent with a different sigma value, and update the condition object if needed
        condition, augment_latent = self.augment_conditional_latent_frames(
            condition, cfg_video_cond_bool, condition_latent, condition_video_augment_sigma_in_inference, sigma, seed
        )
        condition_video_indicator = condition.condition_video_indicator  # [B, 1, T, 1, 1]

        # Compose the model input with condition region (augment_latent) and generation region (noise_x)
        new_noise_xt = condition_video_indicator * augment_latent + (1 - condition_video_indicator) * noise_x
        # Call the base model
        denoise_pred = super().denoise(new_noise_xt, sigma, condition)

        # Replace the condition region of the prediction with the ground-truth latent.
        x0_pred_replaced = condition_video_indicator * gt_latent + (1 - condition_video_indicator) * denoise_pred.x0

        x0_pred = x0_pred_replaced

        return VideoDenoisePrediction(
            x0=x0_pred,
            eps=batch_mul(noise_x - x0_pred, 1.0 / sigma),
            logvar=denoise_pred.logvar,
            xt=new_noise_xt,
            x0_pred_replaced=x0_pred_replaced,
        )

    def generate_samples_from_batch(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        seed: int = 1,
        state_shape: Tuple | None = None,
        n_sample: int | None = None,
        is_negative_prompt: bool = False,
        num_steps: int = 35,
        condition_latent: Union[torch.Tensor, None] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        add_input_frames_guidance: bool = False,
        x_sigma_max: Optional[torch.Tensor] = None,
    ) -> Tensor:
        """Generates video samples conditioned on input frames.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            seed: Random seed for reproducibility
            state_shape: Shape of output tensor (defaults to model's state shape)
            n_sample: Number of samples to generate (defaults to batch size)
            is_negative_prompt: Whether to use negative prompting
            num_steps: Number of denoising steps
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            x_sigma_max: Maximum noise level tensor

        Returns:
            Generated video samples tensor
        """

        if n_sample is None:
            input_key = self.input_data_key
            n_sample = data_batch[input_key].shape[0]
        if state_shape is None:
            log.debug(f"Default Video state shape is used. {self.state_shape}")
            state_shape = self.state_shape

        assert condition_latent is not None, "condition_latent should be provided"

        x0_fn = self.get_x0_fn_from_batch_with_condition_latent(
            data_batch,
            guidance,
            is_negative_prompt=is_negative_prompt,
            condition_latent=condition_latent,
            num_condition_t=num_condition_t,
            condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
            add_input_frames_guidance=add_input_frames_guidance,
            seed=seed,
        )
        if x_sigma_max is None:
            x_sigma_max = (
                misc.arch_invariant_rand(
                    (n_sample,) + tuple(state_shape),
                    torch.float32,
                    self.tensor_kwargs["device"],
                    seed,
                )
                * self.sde.sigma_max
            )

        samples = self.sampler(x0_fn, x_sigma_max, num_steps=num_steps, sigma_max=self.sde.sigma_max)
        return samples

    def get_x0_fn_from_batch_with_condition_latent(
        self,
        data_batch: Dict,
        guidance: float = 1.5,
        is_negative_prompt: bool = False,
        condition_latent: Optional[torch.Tensor] = None,
        num_condition_t: Union[int, None] = None,
        condition_video_augment_sigma_in_inference: Optional[float] = None,
        add_input_frames_guidance: bool = False,
        seed: int = 1,
    ) -> Callable:
        """Creates the denoising function for conditional video generation.

        Args:
            data_batch: Input data dictionary
            guidance: Classifier-free guidance scale
            is_negative_prompt: Whether to use negative prompting
            condition_latent: Conditioning frames tensor (B,C,T,H,W)
            num_condition_t: Number of frames to condition on
            condition_video_augment_sigma_in_inference: Noise level for condition augmentation
            add_input_frames_guidance: Whether to apply guidance to input frames
            seed: Random seed for reproducibility

        Returns:
            Function that takes noisy input and noise level and returns denoised prediction
        """
        if is_negative_prompt:
            condition, uncondition = self.conditioner.get_condition_with_negative_prompt(data_batch)
        else:
            condition, uncondition = self.conditioner.get_condition_uncondition(data_batch)

        condition.video_cond_bool = True
        condition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, condition, num_condition_t
        )

        # When guiding on input frames, the unconditional branch drops the video condition entirely.
        uncondition.video_cond_bool = False if add_input_frames_guidance else True
        uncondition = self.add_condition_video_indicator_and_video_input_mask(
            condition_latent, uncondition, num_condition_t
        )

        def x0_fn(noise_x: torch.Tensor, sigma: torch.Tensor) -> torch.Tensor:
            cond_x0 = self.denoise(
                noise_x,
                sigma,
                condition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced
            uncond_x0 = self.denoise(
                noise_x,
                sigma,
                uncondition,
                condition_video_augment_sigma_in_inference=condition_video_augment_sigma_in_inference,
                seed=seed,
            ).x0_pred_replaced

            return cond_x0 + guidance * (cond_x0 - uncond_x0)

        return x0_fn

    def add_condition_video_indicator_and_video_input_mask(
        self, latent_state: torch.Tensor, condition: VideoExtendCondition, num_condition_t: Union[int, None] = None
    ) -> VideoExtendCondition:
        """Adds conditioning masks to a VideoExtendCondition object.

        Creates binary indicators and input masks for conditional video generation.

        Args:
            latent_state: Input latent tensor (B,C,T,H,W)
            condition: VideoExtendCondition object to update
            num_condition_t: Number of frames to condition on

        Returns:
            Updated VideoExtendCondition with added masks:
            - condition_video_indicator: Binary tensor marking condition regions
            - condition_video_input_mask: Input mask for network
            - gt_latent: Ground truth latent tensor
        """
        T = latent_state.shape[2]
        latent_dtype = latent_state.dtype
        condition_video_indicator = torch.zeros(1, 1, T, 1, 1, device=latent_state.device).type(
            latent_dtype
        )  # 1 for condition region

        # Only in inference to decide the condition region
        assert num_condition_t is not None, "num_condition_t should be provided"
        assert num_condition_t <= T, f"num_condition_t should be less than T, get {num_condition_t}, {T}"
        log.debug(
            f"condition_location first_n, num_condition_t {num_condition_t}, condition.video_cond_bool {condition.video_cond_bool}"
        )
        # Mark the first num_condition_t frames as the condition region.
        condition_video_indicator[:, :, :num_condition_t] += 1.0

        condition.gt_latent = latent_state
        condition.condition_video_indicator = condition_video_indicator

        B, C, T, H, W = latent_state.shape
        # Create additional input_mask channel; this will be concatenated to the input of the network
        # See design doc section (Implementation detail A.1 and A.2) for visualization
        ones_padding = torch.ones((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        zeros_padding = torch.zeros((B, 1, T, H, W), dtype=latent_state.dtype, device=latent_state.device)
        assert condition.video_cond_bool is not None, "video_cond_bool should be set"

        # The input mask indicates whether the input is a conditional region or not
        if condition.video_cond_bool:  # Condition on given video frames
            condition.condition_video_input_mask = (
                condition_video_indicator * ones_padding + (1 - condition_video_indicator) * zeros_padding
            )
        else:  # Unconditional case, used for cfg
            condition.condition_video_input_mask = zeros_padding

        return condition
|
t5_text_encoder.py
ADDED
|
@@ -0,0 +1,108 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from typing import List, Tuple, Union
|
| 17 |
+
|
| 18 |
+
import torch
|
| 19 |
+
import transformers
|
| 20 |
+
from transformers import T5EncoderModel, T5TokenizerFast
|
| 21 |
+
|
| 22 |
+
from Cosmos.utils import log
|
| 23 |
+
|
| 24 |
+
transformers.logging.set_verbosity_error()
|
| 25 |
+
|
| 26 |
+
|
| 27 |
+
class CosmosT5TextEncoder(torch.nn.Module):
    """Handles T5 text encoding operations."""

    def __init__(self, model_name: str = "google-t5/t5-11b", device: str = "cuda", cache_dir: str = "~/.cache"):
        """Initializes the T5 tokenizer and encoder.

        Args:
            model_name: The name of the T5 model to use.
            device: The device to use for computations.
            cache_dir: Directory used to cache downloaded model weights. A leading
                ``~`` is expanded to the user's home directory (the previous
                behavior created a literal ``~`` directory in the working dir).
        """
        super().__init__()
        import os  # local import so the file-level import block stays untouched

        cache_dir = os.path.expanduser(cache_dir)
        try:
            self.tokenizer = T5TokenizerFast.from_pretrained(model_name, cache_dir=cache_dir)
            self.text_encoder = T5EncoderModel.from_pretrained(model_name, cache_dir=cache_dir).to(device)
        except Exception as e:
            log.warning(f"Failed to load T5 model using cache_dir '{cache_dir}', falling back to default location: {e}")
            self.tokenizer = T5TokenizerFast.from_pretrained(model_name)
            self.text_encoder = T5EncoderModel.from_pretrained(model_name).to(device)
        # Inference-only: disable dropout etc.
        self.text_encoder.eval()
        self.device = device

    @torch.inference_mode()
    def encode_prompts(
        self, prompts: Union[str, List[str]], max_length: int = 512
    ) -> Tuple[torch.Tensor, torch.Tensor]:
        """Encodes text prompts into hidden state representations using a T5 encoder.

        Tokenizes the input prompts, runs them through the T5 text encoder, and
        returns the last hidden states. Hidden states at padded positions are
        zeroed out; all prompts in a batch are padded to max_length.

        Args:
            prompts: Input text to encode. Can be a single string or a list of strings.
            max_length: Maximum sequence length for tokenization and padding. Longer
                sequences will be truncated. Defaults to 512.

        Returns:
            tuple[torch.Tensor, torch.Tensor]: A tuple containing:
                - Encoded text embeddings of shape (batch_size, max_length, hidden_size)
                - Attention mask of shape (batch_size, max_length)

        Raises:
            ValueError: If the input prompts list is empty.

        Example:
            >>> encoder = CosmosT5TextEncoder()
            >>> embeddings, mask = encoder.encode_prompts(["Hello world", "Another example"], max_length=128)
        """
        if isinstance(prompts, str):
            prompts = [prompts]

        if not prompts:
            raise ValueError("The input prompt list is empty.")

        batch_encoding = self.tokenizer.batch_encode_plus(
            prompts,
            return_tensors="pt",
            truncation=True,
            padding="max_length",
            max_length=max_length,
        )

        input_ids = batch_encoding.input_ids.to(self.device)
        attn_mask = batch_encoding.attention_mask.to(self.device)

        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attn_mask)

        # Zero out padded positions in one vectorized op instead of a Python loop
        # with a per-batch host sync: with right-padding ("max_length"), mask zeros
        # are exactly the positions past each prompt's true length.
        hidden = outputs.last_hidden_state
        encoded_text = hidden * attn_mask.unsqueeze(-1).to(hidden.dtype)

        return encoded_text, attn_mask
|
text2world.py
ADDED
|
@@ -0,0 +1,161 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from Cosmos.utils import misc
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from Cosmos.inference_utils import add_common_arguments, validate_args
|
| 23 |
+
from Cosmos.world_generation_pipeline import DiffusionText2WorldGenerationPipeline
|
| 24 |
+
from Cosmos.utils import log
|
| 25 |
+
from Cosmos.utils.io import read_prompts_from_file, save_video
|
| 26 |
+
|
| 27 |
+
torch.enable_grad(False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_arguments() -> argparse.Namespace:
    """Build and parse the CLI arguments for the text2world demo."""
    arg_parser = argparse.ArgumentParser(description="Text to world generation demo script")

    # Flags shared by all diffusion demo scripts (checkpoint dir, prompt,
    # resolution, offloading toggles, ...).
    add_common_arguments(arg_parser)

    # Text2world-specific options.
    arg_parser.add_argument(
        "--diffusion_transformer_dir",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Text2World",
        help="DiT model weights directory name relative to checkpoint_dir",
        choices=[
            "Cosmos-1.0-Diffusion-7B-Text2World",
            "Cosmos-1.0-Diffusion-14B-Text2World",
        ],
    )
    arg_parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    arg_parser.add_argument(
        "--word_limit_to_skip_upsampler",
        type=int,
        default=250,
        help="Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value",
    )

    return arg_parser.parse_args()
|
| 61 |
+
|
| 62 |
+
|
| 63 |
+
def demo(cfg):
    """Run text-to-world generation demo.

    Handles the full text-to-world pipeline:
      - seeds RNGs for reproducibility
      - builds the generation pipeline from ``cfg``
      - reads a single prompt or a batch file of prompts
      - generates a video per prompt and writes the video + prompt to disk

    Args:
        cfg (argparse.Namespace): Parsed CLI configuration containing model
            paths, generation parameters (guidance, steps, resolution, fps),
            input/output settings, and offloading flags.

    Side effects:
        Writes MP4 videos and matching ``.txt`` prompt files under
        ``cfg.video_save_folder``. If a guardrail blocks generation, a
        critical message is logged and the loop continues with the next
        prompt.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "text2world"
    validate_args(cfg, inference_type)

    # Initialize text2world generation model pipeline.
    pipeline = DiffusionText2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
    )

    # Handle multiple prompts if a batch file is provided; otherwise use the
    # single CLI prompt.
    if cfg.batch_input_path:
        # BUG FIX: previously referenced the module-global `args` instead of
        # the `cfg` parameter, which raised NameError when demo() was called
        # programmatically (not through the __main__ guard).
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        if current_prompt is None:
            log.critical("Prompt is missing, skipping world generation.")
            continue

        # Generate video; None means a guardrail rejected the prompt or the
        # generated content.
        generated_output = pipeline.generate(current_prompt, cfg.negative_prompt, cfg.word_limit_to_skip_upsampler)
        if generated_output is None:
            log.critical("Guardrail blocked text2world generation.")
            continue
        video, prompt = generated_output

        # Batch outputs are indexed by position; single runs use the
        # user-provided save name.
        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save the (possibly upsampled) prompt alongside the video. Binary
        # mode with explicit UTF-8 keeps the bytes platform-independent.
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
|
| 157 |
+
|
| 158 |
+
|
| 159 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI flags into a module-global `args`, then
    # run the demo. NOTE(review): demo() currently also reads the global
    # `args` in its batch-input log line, so this global binding is
    # load-bearing — confirm before refactoring.
    args = parse_arguments()
    demo(args)
|
text2world_prompt_upsampler_inference.py
ADDED
|
@@ -0,0 +1,157 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
This demo script is used to run inference for Cosmos-1.0-Prompt-Upsampler-12B-Text2World.
|
| 18 |
+
Command:
|
| 19 |
+
PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/text2world_prompt_upsampler_inference.py
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
import argparse
|
| 23 |
+
import os
|
| 24 |
+
import re
|
| 25 |
+
|
| 26 |
+
from cosmos1.models.autoregressive.configs.base.model_config import create_text_model_config
|
| 27 |
+
from cosmos1.models.autoregressive.model import AutoRegressiveModel
|
| 28 |
+
from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
|
| 29 |
+
from Cosmos import guardrail_presets as guardrail_presets
|
| 30 |
+
from Cosmos.utils import log
|
| 31 |
+
|
| 32 |
+
|
| 33 |
+
def create_prompt_upsampler(checkpoint_dir: str) -> AutoRegressiveModel:
    """Instantiate the 12B Mistral-based text2world prompt upsampler on CUDA.

    Args:
        checkpoint_dir: Directory holding ``model.pt`` and the tokenizer files.

    Returns:
        The upsampler LLM, moved to the CUDA device.
    """
    ckpt_path = os.path.join(checkpoint_dir, "model.pt")
    model_config, tokenizer_config = create_text_model_config(
        model_ckpt_path=ckpt_path,
        tokenizer_path=os.path.join(checkpoint_dir),
        model_family="mistral",
        model_size="12b",
        is_instruct_model=True,
        max_batch_size=1,
        rope_dim="1D",
        add_special_tokens=True,
        max_seq_len=1024,
        pytorch_rope_version="v1",
    )
    log.debug(f"Text prompt upsampler model config: {model_config}")

    # Build the LLM instance, then place it on the GPU.
    llm = AutoRegressiveModel.build(
        model_config=model_config,
        tokenizer_config=tokenizer_config,
    )
    return llm.to("cuda")
|
| 53 |
+
|
| 54 |
+
|
| 55 |
+
def run_chat_completion(model: AutoRegressiveModel, input: str, temperature: float = 0.01):
    """Upsample a short caption with the chat-finetuned text2world model.

    The prompt upsampler was trained with a 512-token context window;
    inference uses max_seq_len=1024 to accommodate longer inputs. Setting
    `max_gen_len` is optional, since the finetuned model naturally decides
    when to stop generating.
    """
    dialogs = [[{"role": "user", "content": f"Upsample the short caption to a long caption: {str(input)}"}]]

    completions = chat_completion(
        model,
        dialogs,
        max_gen_len=512,
        temperature=temperature,
        top_p=None,
        top_k=None,
        logprobs=False,
    )
    # Normalize the raw generation before handing it back to callers.
    return str(clean_text(completions[0]["generation"]["content"]))
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
def clean_text(text: str) -> str:
    """Normalize model output: drop markers/prefixes, collapse whitespace, trim punctuation."""
    # Newlines (both CR and LF) become ordinary spaces.
    flattened = text.replace("\n", " ").replace("\r", " ")

    def _keep_or_drop(m: re.Match[str]) -> str:
        # Sections shaped like '- **...**' are dropped when the bold content
        # holds fewer than 10 words; longer sections are preserved verbatim.
        inner = m.group(2)
        return m.group(0) if len(re.findall(r"\w+", inner)) >= 10 else ""

    flattened = re.sub(r"(- \*\*)(.*?)(\*\*)", _keep_or_drop, flattened)

    # Strip known leading markers. Each marker is tested once, in order,
    # against the current text, so several can be removed cumulatively.
    for marker in ("Caption:", "#####", "####", "- ", "* ", ","):
        if flattened.startswith(marker):
            flattened = flattened[len(marker):].lstrip()

    # Collapse runs of whitespace into single spaces.
    flattened = " ".join(flattened.split())

    # Trim residual punctuation/quote characters from both ends.
    return flattened.strip(' -,*:"\'"“”')
|
| 112 |
+
|
| 113 |
+
|
| 114 |
+
def parse_args():
    """Parse CLI flags for the prompt upsampler inference demo."""
    arg_parser = argparse.ArgumentParser(description="Run prompt upsampler inference")
    arg_parser.add_argument("--input", type=str, default="A dog is playing with a ball.")
    arg_parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature")
    arg_parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    arg_parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Cosmos-1.0-Prompt-Upsampler-12B-Text2World",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    arg_parser.add_argument(
        "--guardrail_dir",
        type=str,
        default="Cosmos-1.0-Guardrail",
        help="Guardrail weights directory relative to checkpoint_dir",
    )
    return arg_parser.parse_args()
|
| 134 |
+
|
| 135 |
+
|
| 136 |
+
def main(args):
    """Guardrail-check the input, upsample it, guardrail-check the result, log it."""
    guardrail_runner = guardrail_presets.create_text_guardrail_runner(
        os.path.join(args.checkpoint_dir, args.guardrail_dir)
    )

    # Refuse unsafe input before spending compute on the upsampler.
    if not guardrail_presets.run_text_guardrail(args.input, guardrail_runner):
        log.critical("Input text prompt is not safe.")
        return

    upsampler = create_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir))
    upsampled_prompt = run_chat_completion(upsampler, args.input, temperature=args.temperature)

    # The upsampled text must pass the same guardrail as the raw input.
    if not guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner):
        log.critical("Upsampled text prompt is not safe.")
        return

    log.info(f"Upsampled prompt: {upsampled_prompt}")
|
| 153 |
+
|
| 154 |
+
|
| 155 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI flags and run guardrailed upsampling.
    args = parse_args()
    main(args)
|
types.py
ADDED
|
@@ -0,0 +1,28 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
from __future__ import annotations
|
| 17 |
+
|
| 18 |
+
from dataclasses import dataclass
|
| 19 |
+
from typing import Optional
|
| 20 |
+
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
|
| 24 |
+
@dataclass
class DenoisePrediction:
    """Result of one denoiser forward pass in the diffusion model."""

    x0: torch.Tensor  # clean data prediction
    eps: Optional[torch.Tensor] = None  # noise prediction
    logvar: Optional[torch.Tensor] = None  # log variance of the noise prediction; can be used as a confidence / uncertainty measure
|
video2world.py
ADDED
|
@@ -0,0 +1,179 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from Cosmos.utils import misc
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from Cosmos.inference_utils import add_common_arguments, check_input_frames, validate_args
|
| 23 |
+
from Cosmos.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
|
| 24 |
+
from Cosmos.utils import log
|
| 25 |
+
from Cosmos.utils.io import read_prompts_from_file, save_video
|
| 26 |
+
|
| 27 |
+
torch.enable_grad(False)
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
def parse_arguments() -> argparse.Namespace:
    """Build and parse the CLI arguments for the video2world demo."""
    arg_parser = argparse.ArgumentParser(description="Video to world generation demo script")

    # Flags shared by all diffusion demo scripts (checkpoint dir, prompt,
    # resolution, offloading toggles, ...).
    add_common_arguments(arg_parser)

    # Video2world-specific options.
    arg_parser.add_argument(
        "--diffusion_transformer_dir",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Video2World",
        help="DiT model weights directory name relative to checkpoint_dir",
        choices=[
            "Cosmos-1.0-Diffusion-7B-Video2World",
            "Cosmos-1.0-Diffusion-14B-Video2World",
        ],
    )
    arg_parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    arg_parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        help="Input video/image path for generating a single video",
    )
    arg_parser.add_argument(
        "--num_input_frames",
        type=int,
        default=1,
        help="Number of input frames for video2world prediction",
        choices=[1, 9],
    )

    return arg_parser.parse_args()
|
| 66 |
+
|
| 67 |
+
|
| 68 |
+
def demo(cfg):
    """Run video-to-world generation demo.

    Handles the full video-to-world pipeline:
      - seeds RNGs for reproducibility
      - builds the generation pipeline from ``cfg``
      - reads a single prompt/visual input or a batch file of them
      - generates a video per entry and writes the video + prompt to disk

    Args:
        cfg (argparse.Namespace): Parsed CLI configuration containing model
            paths, generation parameters (guidance, steps, dimensions),
            input/output settings (prompt, visual input, save paths), and
            offloading flags.

    Side effects:
        Writes MP4 videos and matching ``.txt`` prompt files under
        ``cfg.video_save_folder``. If a guardrail blocks generation, a
        critical message is logged and the loop continues with the next
        entry.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "video2world"
    validate_args(cfg, inference_type)

    # Initialize video2world generation model pipeline.
    pipeline = DiffusionVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
    )

    # Handle multiple prompts if a batch file is provided; otherwise use the
    # single CLI prompt + visual input.
    if cfg.batch_input_path:
        # BUG FIX: previously referenced the module-global `args` instead of
        # the `cfg` parameter, which raised NameError when demo() was called
        # programmatically (not through the __main__ guard).
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        # A missing text prompt is only fatal when the upsampler (which could
        # synthesize one from the visual input) is disabled.
        if current_prompt is None and cfg.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_or_video_path = input_dict.get("visual_input", None)
        if current_image_or_video_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue

        # Check input frames
        if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
            continue

        # Generate video; None means a guardrail rejected the prompt or the
        # generated content.
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_or_video_path=current_image_or_video_path,
            negative_prompt=cfg.negative_prompt,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output

        # Batch outputs are indexed by position; single runs use the
        # user-provided save name.
        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save the (possibly upsampled) prompt alongside the video. Binary
        # mode with explicit UTF-8 keeps the bytes platform-independent.
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
|
| 175 |
+
|
| 176 |
+
|
| 177 |
+
if __name__ == "__main__":
    # Script entry point: parse CLI flags into a module-global `args`, then
    # run the demo. NOTE(review): demo() currently also reads the global
    # `args` in its batch-input log line, so this global binding is
    # load-bearing — confirm before refactoring.
    args = parse_arguments()
    demo(args)
|
video2world_hf.py
ADDED
|
@@ -0,0 +1,283 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import argparse
|
| 17 |
+
import os
|
| 18 |
+
|
| 19 |
+
from Cosmos.utils import misc
|
| 20 |
+
import torch
|
| 21 |
+
|
| 22 |
+
from Cosmos.inference_utils import add_common_arguments, check_input_frames, validate_args
|
| 23 |
+
from Cosmos.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
|
| 24 |
+
from Cosmos.utils import log
|
| 25 |
+
from Cosmos.utils.io import read_prompts_from_file, save_video
|
| 26 |
+
|
| 27 |
+
from Cosmos.download_diffusion import main as download_diffusion
|
| 28 |
+
from transformers import PreTrainedModel, PretrainedConfig
|
| 29 |
+
|
| 30 |
+
torch.enable_grad(False)
|
| 31 |
+
|
| 32 |
+
#custom config class
|
| 33 |
+
class DiffusionVideo2WorldConfig(PretrainedConfig):
    """HF-style configuration holder for the DiffusionVideo2World wrapper.

    Every option mirrors a CLI flag of the video2world demo script; each
    value is read from ``**kwargs`` with the same default as the CLI.
    """

    model_type = "DiffusionVideo2World"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # One (attribute, default) pair per demo CLI option; a value present
        # in kwargs overrides the fallback, exactly as before.
        option_defaults = {
            "checkpoint_dir": "checkpoints",
            "tokenizer_dir": "Cosmos-1.0-Tokenizer-CV8x8x8",
            "video_save_name": "output",
            "video_save_folder": "outputs/",
            "prompt": None,
            "batch_input_path": None,
            "negative_prompt": None,
            "num_steps": 35,
            "guidance": 7,
            "num_video_frames": 121,
            "height": 704,
            "width": 1280,
            "fps": 24,
            "seed": 1,
            "disable_prompt_upsampler": False,
            "offload_diffusion_transformer": False,
            "offload_tokenizer": False,
            "offload_text_encoder_model": False,
            "offload_prompt_upsampler": False,
            "offload_guardrail_models": False,
            "diffusion_transformer_dir": "Cosmos-1.0-Diffusion-7B-Video2World",
            "prompt_upsampler_dir": "Pixtral-12B",
            "input_image_or_video_path": None,
            "num_input_frames": 1,
        }
        for option, fallback in option_defaults.items():
            setattr(self, option, kwargs.get(option, fallback))
|
| 63 |
+
class DiffusionVideo2World(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the Cosmos
    video2world diffusion generation pipeline.

    ``forward`` runs the whole generation loop (prompt handling, sampling,
    saving videos and prompts to disk) rather than a single tensor op.
    """

    config_class = DiffusionVideo2WorldConfig

    def __init__(self, config=None):
        # NOTE: the original signature used a mutable default instance
        # (``config=DiffusionVideo2WorldConfig()``) evaluated once at class
        # definition time; build the default lazily instead so each model
        # gets a fresh config object.
        if config is None:
            config = DiffusionVideo2WorldConfig()
        super().__init__(config)
        cfg = config

        misc.set_random_seed(cfg.seed)
        inference_type = "video2world"
        validate_args(cfg, inference_type)

        self.pipeline = DiffusionVideo2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=cfg.checkpoint_dir,
            checkpoint_name=cfg.diffusion_transformer_dir,
            prompt_upsampler_dir=cfg.prompt_upsampler_dir,
            enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
            offload_network=cfg.offload_diffusion_transformer,
            offload_tokenizer=cfg.offload_tokenizer,
            offload_text_encoder_model=cfg.offload_text_encoder_model,
            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
            offload_guardrail_models=cfg.offload_guardrail_models,
            guidance=cfg.guidance,
            num_steps=cfg.num_steps,
            height=cfg.height,
            width=cfg.width,
            fps=cfg.fps,
            num_video_frames=cfg.num_video_frames,
            seed=cfg.seed,
            num_input_frames=cfg.num_input_frames,
        )

    def forward(self):
        """Generate videos for the configured prompt(s) and save them to disk."""
        cfg = self.config

        # Handle multiple prompts if a prompt file is provided
        if cfg.batch_input_path:
            # BUG FIX: the original referenced the bare name ``args`` here,
            # which only exists when the module is run as a script.
            log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
            prompts = read_prompts_from_file(cfg.batch_input_path)
        else:
            # Single prompt case
            prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

        os.makedirs(cfg.video_save_folder, exist_ok=True)
        for i, input_dict in enumerate(prompts):
            current_prompt = input_dict.get("prompt", None)
            if current_prompt is None and cfg.disable_prompt_upsampler:
                log.critical("Prompt is missing, skipping world generation.")
                continue
            current_image_or_video_path = input_dict.get("visual_input", None)
            if current_image_or_video_path is None:
                log.critical("Visual input is missing, skipping world generation.")
                continue

            # Check input frames
            if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
                continue

            # Generate video.
            # BUG FIX: the original called the undefined bare name ``pipeline``
            # instead of ``self.pipeline``.
            generated_output = self.pipeline.generate(
                prompt=current_prompt,
                image_or_video_path=current_image_or_video_path,
                negative_prompt=cfg.negative_prompt,
            )
            if generated_output is None:
                log.critical("Guardrail blocked video2world generation.")
                continue
            video, prompt = generated_output

            if cfg.batch_input_path:
                video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
            else:
                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

            # Save video
            save_video(
                video=video,
                fps=cfg.fps,
                H=cfg.height,
                W=cfg.width,
                video_save_quality=5,
                video_save_path=video_save_path,
            )

            # Save prompt to text file alongside video
            with open(prompt_save_path, "wb") as f:
                f.write(prompt.encode("utf-8"))

            log.info(f"Saved video to {video_save_path}")
            log.info(f"Saved prompt to {prompt_save_path}")

    def save_pretrained(self, save_directory, **kwargs):
        """No-op override: weights are managed externally, nothing is serialized."""
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Download the required diffusion checkpoints, then build the model.

        ``kwargs`` must contain ``config``; any remaining keyword arguments
        are merged into that config before construction.
        """
        config = kwargs["config"]
        other_args = kwargs.copy()
        other_args.pop("config")
        config.update(other_args)
        # Pick the model size from the checkpoint directory name.
        model_sizes = ["7B"] if "7B" in config.diffusion_transformer_dir else ["14B"]
        model_types = ["Video2World"]
        download_diffusion(model_types, model_sizes, config.checkpoint_dir)
        return cls(config)
|
| 171 |
+
|
| 172 |
+
def demo(cfg):
    """Run video-to-world generation demo.

    This function handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        cfg (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    The function will save:
        - Generated MP4 video files
        - Text files containing the processed prompts

    If guardrails block the generation, a critical log message is displayed
    and the function continues to the next prompt if available.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "video2world"
    validate_args(cfg, inference_type)

    # Initialize video2world generation model pipeline
    pipeline = DiffusionVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
    )

    # Handle multiple prompts if prompt file is provided
    if cfg.batch_input_path:
        # BUG FIX: was ``args.batch_input_path`` — ``args`` is only bound in
        # the ``__main__`` guard, so this silently relied on a module global.
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]

    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        if current_prompt is None and cfg.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_or_video_path = input_dict.get("visual_input", None)
        if current_image_or_video_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue

        # Skip inputs that do not have enough frames for conditioning.
        if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
            continue

        # Generate video
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_or_video_path=current_image_or_video_path,
            negative_prompt=cfg.negative_prompt,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output

        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")

        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )

        # Save prompt to text file alongside video
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))

        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
| 279 |
+
|
| 280 |
+
|
| 281 |
+
if __name__ == "__main__":
    # CLI entry point: parse flags and run the video2world demo end to end.
    args = parse_arguments()
    demo(args)
|
video2world_prompt_upsampler_inference.py
ADDED
|
@@ -0,0 +1,167 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
"""
|
| 17 |
+
This demo script is used to run inference for Pixtral-12B.
|
| 18 |
+
Command:
|
| 19 |
+
PYTHONPATH=$(pwd) python cosmos1/models/diffusion/prompt_upsampler/video2world_prompt_upsampler_inference.py
|
| 20 |
+
|
| 21 |
+
"""
|
| 22 |
+
|
| 23 |
+
import argparse
|
| 24 |
+
import os
|
| 25 |
+
from math import ceil
|
| 26 |
+
|
| 27 |
+
from PIL import Image
|
| 28 |
+
|
| 29 |
+
from cosmos1.models.autoregressive.configs.base.model_config import create_vision_language_model_config
|
| 30 |
+
from cosmos1.models.autoregressive.model import AutoRegressiveModel
|
| 31 |
+
from cosmos1.models.diffusion.prompt_upsampler.inference import chat_completion
|
| 32 |
+
from Cosmos import guardrail_presets as guardrail_presets
|
| 33 |
+
from Cosmos.utils import log
|
| 34 |
+
from Cosmos.utils.io import load_from_fileobj
|
| 35 |
+
|
| 36 |
+
|
| 37 |
+
def create_vlm_prompt_upsampler(
    checkpoint_dir: str, tokenizer_ckpt_path: str = "mistral-community/pixtral-12b"
) -> AutoRegressiveModel:
    """
    Load the fine-tuned pixtral model for SimReady.
    If pixtral_ckpt is not provided, use the pretrained checkpoint.

    Args:
        checkpoint_dir: Directory expected to contain ``model.pt`` for the upsampler.
        tokenizer_ckpt_path: HuggingFace repo id or path for the Pixtral tokenizer.

    Returns:
        AutoRegressiveModel: The instantiated model, moved to the CUDA device.
    """
    model_ckpt_path = os.path.join(checkpoint_dir, "model.pt")
    model_config, tokenizer_config = create_vision_language_model_config(
        model_ckpt_path=model_ckpt_path,
        tokenizer_ckpt_path=tokenizer_ckpt_path,
        model_family="pixtral",
        model_size="12b",
        is_instruct_model=True,
        max_batch_size=1,
        max_seq_len=4300,
        pytorch_rope_version="v1",
    )
    # during instantiate, the weights will be downloaded (if not already cached) and loaded
    return AutoRegressiveModel.build(
        model_config=model_config,
        tokenizer_config=tokenizer_config,
    ).to("cuda")
|
| 60 |
+
|
| 61 |
+
|
| 62 |
+
def resize_image(image: Image.Image, max_size: int = 1024) -> Image.Image:
    """Downscale *image* so neither dimension exceeds *max_size*.

    Aspect ratio is preserved (ceiling-rounded); images already within
    bounds are returned unchanged.
    """
    width, height = image.size
    scale = max(width / max_size, height / max_size)
    if scale <= 1:
        return image
    return image.resize((ceil(width / scale), ceil(height / scale)))
|
| 72 |
+
|
| 73 |
+
|
| 74 |
+
def prepare_dialog(image_or_video_path: str) -> list[dict]:
    """Build the single-turn chat dialog used for prompt upsampling.

    For a ``.mp4`` input the last frame is used as the image; any other
    path is opened directly as an image. The image is capped at 1024px.
    """
    if image_or_video_path.endswith(".mp4"):
        frames, _ = load_from_fileobj(image_or_video_path, format="mp4")
        image = Image.fromarray(frames[-1])
    else:
        image = Image.open(image_or_video_path)

    image = resize_image(image, max_size=1024)
    prompt = """\
Your task is to transform a given prompt into a refined and concise video description, no more than 150 words.
Focus only on the content, no filler words or descriptions on the style. Never mention things outside the video.
""".strip()

    user_message = {
        "role": "user",
        "content": "[IMG]\n" + prompt,
        "images": [image],
    }
    return [user_message]
|
| 95 |
+
|
| 96 |
+
|
| 97 |
+
def run_chat_completion(pixtral: AutoRegressiveModel, dialog: list[dict], **inference_args) -> str:
    """Run one chat-completion round on *pixtral* and return the generated text.

    Keyword arguments in ``inference_args`` override the sampling defaults
    defined below.
    """
    sampling_args = {
        "max_gen_len": 400,
        "temperature": 0,
        "top_p": 0.9,
        "logprobs": False,
        "compile_sampling": False,
        "compile_prefill": False,
    }
    sampling_args.update(inference_args)
    results = chat_completion(pixtral, [dialog], **sampling_args)
    # Exactly one dialog in, exactly one generation out.
    assert len(results) == 1
    return str(results[0]["generation"]["content"])
|
| 115 |
+
|
| 116 |
+
|
| 117 |
+
def parse_args(argv=None):
    """Parse command-line options for the prompt upsampler demo.

    Args:
        argv: Optional explicit argument list. Defaults to ``sys.argv[1:]``
            (added for testability; calling with no arguments preserves the
            original behavior).

    Returns:
        argparse.Namespace: The parsed options.
    """
    parser = argparse.ArgumentParser(description="Run prompt upsampler inference")
    parser.add_argument(
        "--image_or_video_path", type=str, default="cosmos1/models/diffusion/assets/v1p0/video2world_input0.jpg"
    )
    parser.add_argument("--temperature", type=float, default=0.01, help="Inference temperature")
    parser.add_argument("--top_p", type=float, default=0.9, help="Top-p value for top-p sampling")
    parser.add_argument(
        "--checkpoint_dir", type=str, default="checkpoints", help="Base directory containing model checkpoints"
    )
    parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--guardrail_dir",
        type=str,
        default="Cosmos-1.0-Guardrail",
        help="Guardrail weights directory relative to checkpoint_dir",
    )
    return parser.parse_args(argv)
|
| 140 |
+
|
| 141 |
+
|
| 142 |
+
def main(args):
    """Upsample a prompt from an image/video input and log it if it passes the guardrail.

    Args:
        args: Parsed CLI namespace (see ``parse_args``) with checkpoint paths,
            the input image/video path, and sampling parameters.
    """
    # The text guardrail screens the *generated* prompt before it is used.
    guardrail_runner = guardrail_presets.create_text_guardrail_runner(
        os.path.join(args.checkpoint_dir, args.guardrail_dir)
    )

    pixtral = create_vlm_prompt_upsampler(os.path.join(args.checkpoint_dir, args.prompt_upsampler_dir))
    dialog = prepare_dialog(args.image_or_video_path)
    upsampled_prompt = run_chat_completion(
        pixtral,
        dialog,
        max_gen_len=400,
        temperature=args.temperature,
        top_p=args.top_p,
        logprobs=False,
    )
    is_safe = guardrail_presets.run_text_guardrail(upsampled_prompt, guardrail_runner)
    if not is_safe:
        # Unsafe output: report and exit without printing the prompt.
        log.critical("Upsampled text prompt is not safe.")
        return

    log.info(f"Upsampled prompt: {upsampled_prompt}")
|
| 163 |
+
|
| 164 |
+
|
| 165 |
+
if __name__ == "__main__":
    # CLI entry point: parse options and run the prompt upsampler demo.
    args = parse_args()
    main(args)
|
world_generation_pipeline.py
ADDED
|
@@ -0,0 +1,658 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
| 2 |
+
# SPDX-License-Identifier: Apache-2.0
|
| 3 |
+
#
|
| 4 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 5 |
+
# you may not use this file except in compliance with the License.
|
| 6 |
+
# You may obtain a copy of the License at
|
| 7 |
+
#
|
| 8 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 9 |
+
#
|
| 10 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 11 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 12 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 13 |
+
# See the License for the specific language governing permissions and
|
| 14 |
+
# limitations under the License.
|
| 15 |
+
|
| 16 |
+
import gc
|
| 17 |
+
import os
|
| 18 |
+
from typing import Any, Optional
|
| 19 |
+
|
| 20 |
+
import numpy as np
|
| 21 |
+
import torch
|
| 22 |
+
|
| 23 |
+
from Cosmos.base_world_generation_pipeline import BaseWorldGenerationPipeline
|
| 24 |
+
from Cosmos.inference_utils import (
|
| 25 |
+
generate_world_from_text,
|
| 26 |
+
generate_world_from_video,
|
| 27 |
+
get_condition_latent,
|
| 28 |
+
get_video_batch,
|
| 29 |
+
load_model_by_config,
|
| 30 |
+
load_network_model,
|
| 31 |
+
load_tokenizer_model,
|
| 32 |
+
)
|
| 33 |
+
from Cosmos.model_t2w import DiffusionT2WModel
|
| 34 |
+
from Cosmos.model_v2w import DiffusionV2WModel
|
| 35 |
+
from Cosmos.text2world_prompt_upsampler_inference import (
|
| 36 |
+
create_prompt_upsampler,
|
| 37 |
+
run_chat_completion,
|
| 38 |
+
)
|
| 39 |
+
from Cosmos.video2world_prompt_upsampler_inference import (
|
| 40 |
+
create_vlm_prompt_upsampler,
|
| 41 |
+
prepare_dialog,
|
| 42 |
+
)
|
| 43 |
+
from Cosmos.video2world_prompt_upsampler_inference import (
|
| 44 |
+
run_chat_completion as run_chat_completion_vlm,
|
| 45 |
+
)
|
| 46 |
+
from Cosmos.utils import log
|
| 47 |
+
|
| 48 |
+
# Maps a checkpoint directory name to the config job name used by
# load_model_by_config (see _load_model below).
MODEL_NAME_DICT = {
    "Cosmos-1.0-Diffusion-7B-Text2World": "Cosmos_1_0_Diffusion_Text2World_7B",
    "Cosmos-1.0-Diffusion-14B-Text2World": "Cosmos_1_0_Diffusion_Text2World_14B",
    "Cosmos-1.0-Diffusion-7B-Video2World": "Cosmos_1_0_Diffusion_Video2World_7B",
    "Cosmos-1.0-Diffusion-14B-Video2World": "Cosmos_1_0_Diffusion_Video2World_14B",
}
|
| 54 |
+
|
| 55 |
+
|
| 56 |
+
class DiffusionText2WorldGenerationPipeline(BaseWorldGenerationPipeline):
|
| 57 |
+
    def __init__(
        self,
        inference_type: str,
        checkpoint_dir: str,
        checkpoint_name: str,
        prompt_upsampler_dir: Optional[str] = None,
        enable_prompt_upsampler: bool = True,
        enable_text_guardrail: bool = True,
        enable_video_guardrail: bool = True,
        offload_network: bool = False,
        offload_tokenizer: bool = False,
        offload_text_encoder_model: bool = False,
        offload_prompt_upsampler: bool = False,
        offload_guardrail_models: bool = False,
        guidance: float = 7.0,
        num_steps: int = 35,
        height: int = 704,
        width: int = 1280,
        fps: int = 24,
        num_video_frames: int = 121,
        seed: int = 0,
    ):
        """Initialize the diffusion world generation pipeline.

        Args:
            inference_type: Type of world generation ('text2world' or 'video2world')
            checkpoint_dir: Base directory containing model checkpoints
            checkpoint_name: Name of the diffusion transformer checkpoint to use
            prompt_upsampler_dir: Directory containing prompt upsampler model weights
            enable_prompt_upsampler: Whether to use prompt upsampling
            enable_text_guardrail: Whether to enable text guardrail
            enable_video_guardrail: Whether to enable video guardrail
            offload_network: Whether to offload diffusion transformer after inference
            offload_tokenizer: Whether to offload tokenizer after inference
            offload_text_encoder_model: Whether to offload T5 model after inference
            offload_prompt_upsampler: Whether to offload prompt upsampler
            offload_guardrail_models: Whether to offload guardrail models
            guidance: Classifier-free guidance scale
            num_steps: Number of diffusion sampling steps
            height: Height of output video
            width: Width of output video
            fps: Frames per second of output video
            num_video_frames: Number of frames to generate
            seed: Random seed for sampling
        """
        assert inference_type in [
            "text2world",
            "video2world",
        ], "Invalid inference_type, must be 'text2world' or 'video2world'"

        # Sampling attributes must be set before super().__init__, which may
        # trigger model loading that reads them.
        self.model_name = MODEL_NAME_DICT[checkpoint_name]
        self.guidance = guidance
        self.num_steps = num_steps
        self.height = height
        self.width = width
        self.fps = fps
        self.num_video_frames = num_video_frames
        self.seed = seed

        super().__init__(
            inference_type=inference_type,
            checkpoint_dir=checkpoint_dir,
            checkpoint_name=checkpoint_name,
            enable_text_guardrail=enable_text_guardrail,
            enable_video_guardrail=enable_video_guardrail,
            offload_network=offload_network,
            offload_tokenizer=offload_tokenizer,
            offload_text_encoder_model=offload_text_encoder_model,
            offload_guardrail_models=offload_guardrail_models,
        )
        self.prompt_upsampler_dir = prompt_upsampler_dir
        self.enable_prompt_upsampler = enable_prompt_upsampler
        self.offload_prompt_upsampler = offload_prompt_upsampler

        # Load the upsampler eagerly only when it is both enabled and not
        # configured for on-demand (offloaded) loading.
        self.prompt_upsampler = None
        if enable_prompt_upsampler and not offload_prompt_upsampler:
            self._load_prompt_upsampler_model()
|
| 134 |
+
|
| 135 |
+
    def _load_prompt_upsampler_model(self):
        """Instantiate the prompt upsampler from its checkpoint directory."""
        self.prompt_upsampler = create_prompt_upsampler(
            checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir),
        )
|
| 139 |
+
|
| 140 |
+
    def _load_model(self):
        """Build the text2world diffusion model object from its registered config job."""
        self.model = load_model_by_config(
            config_job_name=self.model_name,
            config_file="cosmos1/models/diffusion/config/config.py",
            model_class=DiffusionT2WModel,
        )
|
| 146 |
+
|
| 147 |
+
    def _load_network(self):
        """Load the diffusion transformer weights from the checkpoint directory."""
        load_network_model(self.model, f"{self.checkpoint_dir}/{self.checkpoint_name}/model.pt")
|
| 149 |
+
|
| 150 |
+
    def _load_tokenizer(self):
        """Load the CV8x8x8 video tokenizer weights into the model."""
        load_tokenizer_model(self.model, f"{self.checkpoint_dir}/Cosmos-1.0-Tokenizer-CV8x8x8")
|
| 152 |
+
|
| 153 |
+
def _offload_prompt_upsampler_model(self):
|
| 154 |
+
"""Move prompt enhancement model to CPU/disk.
|
| 155 |
+
|
| 156 |
+
Offloads prompt upsampling model after processing input
|
| 157 |
+
to reduce GPU memory usage.
|
| 158 |
+
"""
|
| 159 |
+
if self.prompt_upsampler:
|
| 160 |
+
del self.prompt_upsampler
|
| 161 |
+
self.prompt_upsampler = None
|
| 162 |
+
gc.collect()
|
| 163 |
+
torch.cuda.empty_cache()
|
| 164 |
+
|
| 165 |
+
    def _run_prompt_upsampler_on_prompt(self, prompt: str) -> str:
        """Enhance the input prompt using the prompt upsampler model.

        Requires ``self.prompt_upsampler`` to be loaded
        (see ``_load_prompt_upsampler_model``).

        Args:
            prompt: Raw text prompt to be enhanced

        Returns:
            str: Enhanced version of the input prompt with more descriptive details
        """
        upsampled_prompt = run_chat_completion(self.prompt_upsampler, prompt)
        log.info(f"Upsampled prompt: {upsampled_prompt}")
        return upsampled_prompt
|
| 177 |
+
|
| 178 |
+
def _run_prompt_upsampler_on_prompt_with_offload(self, *args: Any, **kwargs: Any) -> str:
|
| 179 |
+
"""Enhance prompt with prompt upsampler model.
|
| 180 |
+
|
| 181 |
+
Args:
|
| 182 |
+
*args: Positional arguments
|
| 183 |
+
**kwargs: Keyword arguments
|
| 184 |
+
|
| 185 |
+
Returns:
|
| 186 |
+
Enhanced prompt string
|
| 187 |
+
"""
|
| 188 |
+
if self.offload_prompt_upsampler:
|
| 189 |
+
self._load_prompt_upsampler_model()
|
| 190 |
+
|
| 191 |
+
enhanced_prompt = self._run_prompt_upsampler_on_prompt(*args, **kwargs)
|
| 192 |
+
|
| 193 |
+
if self.offload_prompt_upsampler:
|
| 194 |
+
self._offload_prompt_upsampler_model()
|
| 195 |
+
|
| 196 |
+
return enhanced_prompt
|
| 197 |
+
|
| 198 |
+
def _run_tokenizer_decoding(self, sample: torch.Tensor) -> np.ndarray:
|
| 199 |
+
"""Decode latent samples to video frames using the tokenizer decoder.
|
| 200 |
+
|
| 201 |
+
Args:
|
| 202 |
+
sample: Latent tensor from diffusion model [B, C, T, H, W]
|
| 203 |
+
|
| 204 |
+
Returns:
|
| 205 |
+
np.ndarray: Decoded video frames as uint8 numpy array [T, H, W, C]
|
| 206 |
+
with values in range [0, 255]
|
| 207 |
+
"""
|
| 208 |
+
# Decode video
|
| 209 |
+
video = (1.0 + self.model.decode(sample)).clamp(0, 2) / 2 # [B, 3, T, H, W]
|
| 210 |
+
video = (video[0].permute(1, 2, 3, 0) * 255).to(torch.uint8).cpu().numpy()
|
| 211 |
+
|
| 212 |
+
return video
|
| 213 |
+
|
| 214 |
+
    def _run_model(
        self,
        embedding: torch.Tensor,
        negative_prompt_embedding: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
        """Generate video latents using the diffusion model.

        Args:
            embedding: Text embedding tensor from text encoder
            negative_prompt_embedding: Optional embedding for negative prompt guidance

        Returns:
            torch.Tensor: Generated video latents before tokenizer decoding

        Note:
            The model and tokenizer are automatically offloaded after inference
            if offloading is enabled in the config.
        """
        # Get video batch and state shape
        data_batch, state_shape = get_video_batch(
            model=self.model,
            prompt_embedding=embedding,
            negative_prompt_embedding=negative_prompt_embedding,
            height=self.height,
            width=self.width,
            fps=self.fps,
            num_video_frames=self.num_video_frames,
        )

        # Generate video frames; negative-prompt guidance is only active when
        # a negative embedding was actually supplied.
        sample = generate_world_from_text(
            model=self.model,
            state_shape=state_shape,
            is_negative_prompt=True if negative_prompt_embedding is not None else False,
            data_batch=data_batch,
            guidance=self.guidance,
            num_steps=self.num_steps,
            seed=self.seed,
        )

        return sample
|
| 255 |
+
|
| 256 |
+
def _run_model_with_offload(
|
| 257 |
+
self, prompt_embedding: torch.Tensor, negative_prompt_embedding: Optional[torch.Tensor] = None
|
| 258 |
+
) -> np.ndarray:
|
| 259 |
+
"""Generate world representation with automatic model offloading.
|
| 260 |
+
|
| 261 |
+
Wraps the core generation process with model loading/offloading logic
|
| 262 |
+
to minimize GPU memory usage during inference.
|
| 263 |
+
|
| 264 |
+
Args:
|
| 265 |
+
*args: Positional arguments passed to _run_model
|
| 266 |
+
**kwargs: Keyword arguments passed to _run_model
|
| 267 |
+
|
| 268 |
+
Returns:
|
| 269 |
+
np.ndarray: Generated world representation as numpy array
|
| 270 |
+
"""
|
| 271 |
+
if self.offload_network:
|
| 272 |
+
self._load_network()
|
| 273 |
+
|
| 274 |
+
if self.offload_tokenizer:
|
| 275 |
+
self._load_tokenizer()
|
| 276 |
+
|
| 277 |
+
sample = self._run_model(prompt_embedding, negative_prompt_embedding)
|
| 278 |
+
|
| 279 |
+
if self.offload_network:
|
| 280 |
+
self._offload_network()
|
| 281 |
+
|
| 282 |
+
if self.offload_tokenizer:
|
| 283 |
+
self._load_tokenizer()
|
| 284 |
+
|
| 285 |
+
sample = self._run_tokenizer_decoding(sample)
|
| 286 |
+
|
| 287 |
+
if self.offload_tokenizer:
|
| 288 |
+
self._offload_tokenizer()
|
| 289 |
+
return sample
|
| 290 |
+
|
| 291 |
+
    def generate(
        self,
        prompt: str,
        negative_prompt: Optional[str] = None,
        word_limit_to_skip_upsampler: Optional[int] = None,
    ) -> tuple[np.ndarray, str] | None:
        """Generate video from text prompt with optional negative prompt guidance.

        Pipeline steps:
        1. Run safety checks on input prompt
        2. Enhance prompt using upsampler if enabled
        3. Run safety checks on upsampled prompt if applicable
        4. Convert prompt to embeddings
        5. Generate video frames using diffusion
        6. Run safety checks and apply face blur on generated video frames

        Args:
            prompt: Text description of desired video
            negative_prompt: Optional text to guide what not to generate
            word_limit_to_skip_upsampler: Skip prompt upsampler for better robustness if the number of words in the prompt is greater than this value
        Returns:
            tuple: (
                Generated video frames as uint8 np.ndarray [T, H, W, C],
                Final prompt used for generation (may be enhanced)
            ), or None if content fails guardrail safety checks
        """
        log.info(f"Run with prompt: {prompt}")
        log.info(f"Run with negative prompt: {negative_prompt}")
        log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}")

        # Safety-check the raw user prompt before any further processing.
        if self.enable_text_guardrail:
            log.info("Run guardrail on prompt")
            is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
            if not is_safe:
                log.critical("Input text prompt is not safe")
                return None
            log.info("Pass guardrail on prompt")

        # Enhance prompt
        if self.enable_prompt_upsampler:
            word_count = len(prompt.split())
            # Very long prompts bypass the upsampler for robustness (see parameter doc).
            if word_limit_to_skip_upsampler is None or word_count <= word_limit_to_skip_upsampler:
                log.info("Run prompt upsampler on prompt")
                prompt = self._run_prompt_upsampler_on_prompt_with_offload(prompt)
                # Re-check safety: the upsampler may introduce new content.
                if self.enable_text_guardrail:
                    log.info("Run guardrail on upsampled prompt")
                    is_safe = self._run_guardrail_on_prompt_with_offload(prompt=prompt)
                    if not is_safe:
                        log.critical("Upsampled text prompt is not safe")
                        return None
                    log.info("Pass guardrail on upsampled prompt")
            else:
                log.info(
                    f"Skip prompt upsampler for better robustness because the number of words ({word_count}) in the prompt is greater than {word_limit_to_skip_upsampler}"
                )

        log.info("Run text embedding on prompt")
        # Positive (and optional negative) prompts are encoded in one batch;
        # embeddings come back in input order.
        if negative_prompt:
            prompts = [prompt, negative_prompt]
        else:
            prompts = [prompt]
        prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts)
        prompt_embedding = prompt_embeddings[0]
        negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None
        log.info("Finish text embedding on prompt")

        # Generate video
        log.info("Run generation")
        video = self._run_model_with_offload(
            prompt_embedding,
            negative_prompt_embedding=negative_prompt_embedding,
        )
        log.info("Finish generation")

        if self.enable_video_guardrail:
            log.info("Run guardrail on generated video")
            # The guardrail may return modified frames; None signals unsafe content.
            video = self._run_guardrail_on_video_with_offload(video)
            if video is None:
                log.critical("Generated video is not safe")
                return None
            log.info("Pass guardrail on generated video")

        return video, prompt
|
| 374 |
+
|
| 375 |
+
|
| 376 |
+
class DiffusionVideo2WorldGenerationPipeline(DiffusionText2WorldGenerationPipeline):
|
| 377 |
+
def __init__(
|
| 378 |
+
self,
|
| 379 |
+
inference_type: str,
|
| 380 |
+
checkpoint_dir: str,
|
| 381 |
+
checkpoint_name: str,
|
| 382 |
+
prompt_upsampler_dir: Optional[str] = None,
|
| 383 |
+
enable_prompt_upsampler: bool = True,
|
| 384 |
+
enable_text_guardrail: bool = True,
|
| 385 |
+
enable_video_guardrail: bool = True,
|
| 386 |
+
offload_network: bool = False,
|
| 387 |
+
offload_tokenizer: bool = False,
|
| 388 |
+
offload_text_encoder_model: bool = False,
|
| 389 |
+
offload_prompt_upsampler: bool = False,
|
| 390 |
+
offload_guardrail_models: bool = False,
|
| 391 |
+
guidance: float = 7.0,
|
| 392 |
+
num_steps: int = 35,
|
| 393 |
+
height: int = 704,
|
| 394 |
+
width: int = 1280,
|
| 395 |
+
fps: int = 24,
|
| 396 |
+
num_video_frames: int = 121,
|
| 397 |
+
seed: int = 0,
|
| 398 |
+
num_input_frames: int = 1,
|
| 399 |
+
):
|
| 400 |
+
"""Initialize diffusion world generation pipeline.
|
| 401 |
+
|
| 402 |
+
Args:
|
| 403 |
+
inference_type: Type of world generation ('text2world' or 'video2world')
|
| 404 |
+
checkpoint_dir: Base directory containing model checkpoints
|
| 405 |
+
checkpoint_name: Name of the diffusion transformer checkpoint to use
|
| 406 |
+
prompt_upsampler_dir: Directory containing prompt upsampler model weights
|
| 407 |
+
enable_prompt_upsampler: Whether to use prompt upsampling
|
| 408 |
+
enable_text_guardrail: Whether to enable text guardrail
|
| 409 |
+
enable_video_guardrail: Whether to enable video guardrail
|
| 410 |
+
offload_network: Whether to offload diffusion transformer after inference
|
| 411 |
+
offload_tokenizer: Whether to offload tokenizer after inference
|
| 412 |
+
offload_text_encoder_model: Whether to offload T5 model after inference
|
| 413 |
+
offload_prompt_upsampler: Whether to offload prompt upsampler
|
| 414 |
+
offload_guardrail_models: Whether to offload guardrail models
|
| 415 |
+
guidance: Classifier-free guidance scale
|
| 416 |
+
num_steps: Number of diffusion sampling steps
|
| 417 |
+
height: Height of output video
|
| 418 |
+
width: Width of output video
|
| 419 |
+
fps: Frames per second of output video
|
| 420 |
+
num_video_frames: Number of frames to generate
|
| 421 |
+
seed: Random seed for sampling
|
| 422 |
+
num_input_frames: Number of latent conditions
|
| 423 |
+
"""
|
| 424 |
+
self.num_input_frames = num_input_frames
|
| 425 |
+
super().__init__(
|
| 426 |
+
inference_type=inference_type,
|
| 427 |
+
checkpoint_dir=checkpoint_dir,
|
| 428 |
+
checkpoint_name=checkpoint_name,
|
| 429 |
+
prompt_upsampler_dir=prompt_upsampler_dir,
|
| 430 |
+
enable_prompt_upsampler=enable_prompt_upsampler,
|
| 431 |
+
enable_text_guardrail=enable_text_guardrail,
|
| 432 |
+
enable_video_guardrail=enable_video_guardrail,
|
| 433 |
+
offload_network=offload_network,
|
| 434 |
+
offload_tokenizer=offload_tokenizer,
|
| 435 |
+
offload_text_encoder_model=offload_text_encoder_model,
|
| 436 |
+
offload_prompt_upsampler=offload_prompt_upsampler,
|
| 437 |
+
offload_guardrail_models=offload_guardrail_models,
|
| 438 |
+
guidance=guidance,
|
| 439 |
+
num_steps=num_steps,
|
| 440 |
+
height=height,
|
| 441 |
+
width=width,
|
| 442 |
+
fps=fps,
|
| 443 |
+
num_video_frames=num_video_frames,
|
| 444 |
+
seed=seed,
|
| 445 |
+
)
|
| 446 |
+
|
| 447 |
+
def _run_prompt_upsampler_on_prompt(self, image_or_video_path: str) -> str:
|
| 448 |
+
"""Enhance the input prompt using visual context from the conditioning image.
|
| 449 |
+
|
| 450 |
+
Args:
|
| 451 |
+
image_or_video_path: Path to conditioning image or video used for visual context
|
| 452 |
+
|
| 453 |
+
Returns:
|
| 454 |
+
str: Enhanced prompt incorporating visual details from the image
|
| 455 |
+
"""
|
| 456 |
+
dialog = prepare_dialog(image_or_video_path)
|
| 457 |
+
upsampled_prompt = run_chat_completion_vlm(
|
| 458 |
+
self.prompt_upsampler, dialog, max_gen_len=400, temperature=0.01, top_p=0.9, logprobs=False
|
| 459 |
+
)
|
| 460 |
+
log.info(f"Upsampled prompt: {upsampled_prompt}")
|
| 461 |
+
return upsampled_prompt
|
| 462 |
+
|
| 463 |
+
def _load_prompt_upsampler_model(self):
|
| 464 |
+
self.prompt_upsampler = create_vlm_prompt_upsampler(
|
| 465 |
+
checkpoint_dir=os.path.join(self.checkpoint_dir, self.prompt_upsampler_dir),
|
| 466 |
+
)
|
| 467 |
+
|
| 468 |
+
def _load_model(self):
|
| 469 |
+
self.model = load_model_by_config(
|
| 470 |
+
config_job_name=self.model_name,
|
| 471 |
+
config_file="cosmos1/models/diffusion/config/config.py",
|
| 472 |
+
model_class=DiffusionV2WModel,
|
| 473 |
+
)
|
| 474 |
+
|
| 475 |
+
def _run_model(
|
| 476 |
+
self,
|
| 477 |
+
embedding: torch.Tensor,
|
| 478 |
+
condition_latent: torch.Tensor,
|
| 479 |
+
negative_prompt_embedding: torch.Tensor | None = None,
|
| 480 |
+
) -> torch.Tensor:
|
| 481 |
+
"""Generate video frames using the diffusion model.
|
| 482 |
+
|
| 483 |
+
Args:
|
| 484 |
+
embedding: Text embedding tensor from T5 encoder
|
| 485 |
+
condition_latent: Latent tensor from conditioning image or video
|
| 486 |
+
negative_prompt_embedding: Optional embedding for negative prompt guidance
|
| 487 |
+
|
| 488 |
+
Returns:
|
| 489 |
+
Tensor of generated video frames
|
| 490 |
+
|
| 491 |
+
Note:
|
| 492 |
+
Model and tokenizer are automatically offloaded after inference
|
| 493 |
+
if offloading is enabled.
|
| 494 |
+
"""
|
| 495 |
+
# Get video batch and state shape
|
| 496 |
+
data_batch, state_shape = get_video_batch(
|
| 497 |
+
model=self.model,
|
| 498 |
+
prompt_embedding=embedding,
|
| 499 |
+
negative_prompt_embedding=negative_prompt_embedding,
|
| 500 |
+
height=self.height,
|
| 501 |
+
width=self.width,
|
| 502 |
+
fps=self.fps,
|
| 503 |
+
num_video_frames=self.num_video_frames,
|
| 504 |
+
)
|
| 505 |
+
|
| 506 |
+
# Generate video frames
|
| 507 |
+
video = generate_world_from_video(
|
| 508 |
+
model=self.model,
|
| 509 |
+
state_shape=self.model.state_shape,
|
| 510 |
+
is_negative_prompt=True,
|
| 511 |
+
data_batch=data_batch,
|
| 512 |
+
guidance=self.guidance,
|
| 513 |
+
num_steps=self.num_steps,
|
| 514 |
+
seed=self.seed,
|
| 515 |
+
condition_latent=condition_latent,
|
| 516 |
+
num_input_frames=self.num_input_frames,
|
| 517 |
+
)
|
| 518 |
+
|
| 519 |
+
return video
|
| 520 |
+
|
| 521 |
+
def _run_tokenizer_encoding(self, image_or_video_path: str) -> torch.Tensor:
|
| 522 |
+
"""
|
| 523 |
+
Encode image to latent space
|
| 524 |
+
|
| 525 |
+
Args:
|
| 526 |
+
image_or_video_path: Path to conditioning image
|
| 527 |
+
|
| 528 |
+
Returns:
|
| 529 |
+
torch.Tensor: Latent tensor from tokenizer encoding
|
| 530 |
+
"""
|
| 531 |
+
condition_latent = get_condition_latent(
|
| 532 |
+
model=self.model,
|
| 533 |
+
input_image_or_video_path=image_or_video_path,
|
| 534 |
+
num_input_frames=self.num_input_frames,
|
| 535 |
+
state_shape=self.model.state_shape,
|
| 536 |
+
)
|
| 537 |
+
|
| 538 |
+
return condition_latent
|
| 539 |
+
|
| 540 |
+
    def _run_model_with_offload(
        self,
        prompt_embedding: torch.Tensor,
        image_or_video_path: str,
        negative_prompt_embedding: Optional[torch.Tensor] = None,
    ) -> np.ndarray:
        """Generate world representation with automatic model offloading.

        Wraps the core generation process with model loading/offloading logic
        to minimize GPU memory usage during inference.

        Args:
            prompt_embedding: Text embedding tensor from T5 encoder
            image_or_video_path: Path to conditioning image or video
            negative_prompt_embedding: Optional embedding for negative prompt guidance

        Returns:
            np.ndarray: Generated world representation as numpy array
        """
        # Tokenizer is needed first to encode the conditioning frames.
        if self.offload_tokenizer:
            self._load_tokenizer()

        condition_latent = self._run_tokenizer_encoding(image_or_video_path)

        # Diffusion network is loaded only for the sampling step.
        if self.offload_network:
            self._load_network()

        sample = self._run_model(prompt_embedding, condition_latent, negative_prompt_embedding)

        # Free the network before decoding to reduce peak GPU memory.
        if self.offload_network:
            self._offload_network()

        # Tokenizer stays resident from the encoding step for decoding.
        sample = self._run_tokenizer_decoding(sample)

        if self.offload_tokenizer:
            self._offload_tokenizer()

        return sample
|
| 578 |
+
|
| 579 |
+
    def generate(
        self,
        prompt: str,
        image_or_video_path: str,
        negative_prompt: Optional[str] = None,
    ) -> tuple[np.ndarray, str] | None:
        """Generate video from text prompt and optional image.

        Pipeline steps:
        1. Run safety checks on input prompt
        2. Enhance prompt using upsampler if enabled
        3. Run safety checks on upsampled prompt if applicable
        4. Convert prompt to embeddings
        5. Generate video frames using diffusion
        6. Run safety checks and apply face blur on generated video frames

        Args:
            prompt: Text description of desired video
            image_or_video_path: Path to conditioning image or video
            negative_prompt: Optional text to guide what not to generate

        Returns:
            tuple: (
                Generated video frames as uint8 np.ndarray [T, H, W, C],
                Final prompt used for generation (may be enhanced)
            ), or None if content fails guardrail safety checks
        """
        log.info(f"Run with prompt: {prompt}")
        log.info(f"Run with image or video path: {image_or_video_path}")
        log.info(f"Run with negative prompt: {negative_prompt}")
        log.info(f"Run with prompt upsampler: {self.enable_prompt_upsampler}")

        # When the upsampler is enabled the input prompt is discarded (a fresh prompt
        # is generated from the image/video), so the input prompt is only
        # guardrail-checked when it will actually be used.
        if self.enable_text_guardrail and not self.enable_prompt_upsampler:
            log.info("Run guardrail on prompt")
            is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
            if not is_safe:
                log.critical("Input text prompt is not safe")
                return None
            log.info("Pass guardrail on prompt")

        # Enhance prompt
        if self.enable_prompt_upsampler:
            log.info("Run prompt upsampler on image or video, input prompt is not used")
            prompt = self._run_prompt_upsampler_on_prompt_with_offload(image_or_video_path=image_or_video_path)
            # Re-check safety: the upsampler produced entirely new text.
            if self.enable_text_guardrail:
                log.info("Run guardrail on upsampled prompt")
                is_safe = self._run_guardrail_on_prompt_with_offload(prompt)
                if not is_safe:
                    log.critical("Upsampled text prompt is not safe")
                    return None
                log.info("Pass guardrail on upsampled prompt")

        log.info("Run text embedding on prompt")
        # Positive (and optional negative) prompts are encoded in one batch;
        # embeddings come back in input order.
        if negative_prompt:
            prompts = [prompt, negative_prompt]
        else:
            prompts = [prompt]
        prompt_embeddings, _ = self._run_text_embedding_on_prompt_with_offload(prompts)
        prompt_embedding = prompt_embeddings[0]
        negative_prompt_embedding = prompt_embeddings[1] if negative_prompt else None
        log.info("Finish text embedding on prompt")

        # Generate video
        log.info("Run generation")
        video = self._run_model_with_offload(
            prompt_embedding,
            negative_prompt_embedding=negative_prompt_embedding,
            image_or_video_path=image_or_video_path,
        )
        log.info("Finish generation")

        if self.enable_video_guardrail:
            log.info("Run guardrail on generated video")
            # The guardrail may return modified frames; None signals unsafe content.
            video = self._run_guardrail_on_video_with_offload(video)
            if video is None:
                log.critical("Generated video is not safe")
                return None
            log.info("Pass guardrail on generated video")

        return video, prompt
|