# DiffusionVideo2WorldGeneration / video2world_hf.py
# Hugging Face file-viewer metadata: uploaded by tchoudha21 ("Upload modified files"),
# commit 5a33b3a (verified), file size 12.1 kB.
# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
import argparse
import os
from Cosmos.utils import misc
import torch
from Cosmos.inference_utils import add_common_arguments, check_input_frames, validate_args
from Cosmos.world_generation_pipeline import DiffusionVideo2WorldGenerationPipeline
from Cosmos.utils import log
from Cosmos.utils.io import read_prompts_from_file, save_video
from Cosmos.download_diffusion import main as download_diffusion
from transformers import PreTrainedModel, PretrainedConfig
# Disable autograd: this module only runs inference. Grad mode is thread-local,
# so this applies to the importing thread. (``torch.enable_grad(False)`` does not
# do this: ``enable_grad`` is a context manager/decorator, and calling it with an
# argument leaves grad mode unchanged.)
torch.set_grad_enabled(False)
#custom config class
class DiffusionVideo2WorldConfig(PretrainedConfig):
    """Hugging Face config carrying every option of the video2world model.

    The attributes correspond to the fields read when building and running
    ``DiffusionVideo2WorldGenerationPipeline`` in this module (``tokenizer_dir``
    is stored but not read directly in this file). Any keyword argument that is
    not supplied falls back to the default listed in ``__init__``.
    """

    model_type = "DiffusionVideo2World"

    def __init__(self, **kwargs):
        super().__init__(**kwargs)
        # (attribute name, fallback value) pairs, assigned in this order.
        option_table = (
            ("checkpoint_dir", "checkpoints"),
            ("tokenizer_dir", "Cosmos-1.0-Tokenizer-CV8x8x8"),
            ("video_save_name", "output"),
            ("video_save_folder", "outputs/"),
            ("prompt", None),
            ("batch_input_path", None),
            ("negative_prompt", None),
            ("num_steps", 35),
            ("guidance", 7),
            ("num_video_frames", 121),
            ("height", 704),
            ("width", 1280),
            ("fps", 24),
            ("seed", 1),
            ("disable_prompt_upsampler", False),
            ("offload_diffusion_transformer", False),
            ("offload_tokenizer", False),
            ("offload_text_encoder_model", False),
            ("offload_prompt_upsampler", False),
            ("offload_guardrail_models", False),
            ("diffusion_transformer_dir", "Cosmos-1.0-Diffusion-7B-Video2World"),
            ("prompt_upsampler_dir", "Pixtral-12B"),
            ("input_image_or_video_path", None),
            ("num_input_frames", 1),
        )
        for option_name, fallback in option_table:
            setattr(self, option_name, kwargs.get(option_name, fallback))
class DiffusionVideo2World(PreTrainedModel):
    """Hugging Face ``PreTrainedModel`` wrapper around the Cosmos video2world pipeline.

    Construction builds a ``DiffusionVideo2WorldGenerationPipeline`` from the
    config. ``forward()`` runs generation for the input(s) described by the
    config and writes videos and prompts into ``config.video_save_folder``.
    """

    config_class = DiffusionVideo2WorldConfig

    def __init__(self, config=None):
        """Seed RNGs, validate the config and build the generation pipeline.

        Args:
            config: A ``DiffusionVideo2WorldConfig``. When omitted, a fresh
                default config is created for this instance.
        """
        if config is None:
            # Build per call. A default-argument instance would be created once
            # at class-definition time and shared (and mutated) by every model.
            config = DiffusionVideo2WorldConfig()
        super().__init__(config)
        cfg = config
        misc.set_random_seed(cfg.seed)
        inference_type = "video2world"
        validate_args(cfg, inference_type)
        self.pipeline = DiffusionVideo2WorldGenerationPipeline(
            inference_type=inference_type,
            checkpoint_dir=cfg.checkpoint_dir,
            checkpoint_name=cfg.diffusion_transformer_dir,
            prompt_upsampler_dir=cfg.prompt_upsampler_dir,
            enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
            offload_network=cfg.offload_diffusion_transformer,
            offload_tokenizer=cfg.offload_tokenizer,
            offload_text_encoder_model=cfg.offload_text_encoder_model,
            offload_prompt_upsampler=cfg.offload_prompt_upsampler,
            offload_guardrail_models=cfg.offload_guardrail_models,
            guidance=cfg.guidance,
            num_steps=cfg.num_steps,
            height=cfg.height,
            width=cfg.width,
            fps=cfg.fps,
            num_video_frames=cfg.num_video_frames,
            seed=cfg.seed,
            num_input_frames=cfg.num_input_frames,
        )

    def forward(self):
        """Generate and save one video (plus its prompt) per configured input.

        Inputs are read with ``read_prompts_from_file`` when
        ``config.batch_input_path`` is set. Otherwise a single input is built from
        ``config.prompt`` / ``config.input_image_or_video_path``.

        An input is logged and skipped when any of the following holds:
        - it has no visual input;
        - it has no prompt while the prompt upsampler is disabled;
        - it fails ``check_input_frames``;
        - ``pipeline.generate`` returns None (logged as a guardrail block).

        Outputs are named ``{i}.mp4``/``{i}.txt`` in batch mode, otherwise
        ``{video_save_name}.mp4``/``.txt``.
        """
        cfg = self.config
        if cfg.batch_input_path:
            # Read from cfg, not a global ``args``: this method must also work
            # when the module is imported as a library.
            log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
            prompts = read_prompts_from_file(cfg.batch_input_path)
        else:
            # Single prompt case
            prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]
        os.makedirs(cfg.video_save_folder, exist_ok=True)
        for i, input_dict in enumerate(prompts):
            current_prompt = input_dict.get("prompt", None)
            if current_prompt is None and cfg.disable_prompt_upsampler:
                log.critical("Prompt is missing, skipping world generation.")
                continue
            current_image_or_video_path = input_dict.get("visual_input", None)
            if current_image_or_video_path is None:
                log.critical("Visual input is missing, skipping world generation.")
                continue
            if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
                continue
            # The pipeline is stored on the instance by __init__.
            generated_output = self.pipeline.generate(
                prompt=current_prompt,
                image_or_video_path=current_image_or_video_path,
                negative_prompt=cfg.negative_prompt,
            )
            if generated_output is None:
                log.critical("Guardrail blocked video2world generation.")
                continue
            video, prompt = generated_output
            if cfg.batch_input_path:
                video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
            else:
                video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
                prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
            save_video(
                video=video,
                fps=cfg.fps,
                H=cfg.height,
                W=cfg.width,
                video_save_quality=5,
                video_save_path=video_save_path,
            )
            # Save the (possibly upsampled) prompt as UTF-8 alongside the video.
            with open(prompt_save_path, "wb") as f:
                f.write(prompt.encode("utf-8"))
            log.info(f"Saved video to {video_save_path}")
            log.info(f"Saved prompt to {prompt_save_path}")

    def save_pretrained(self, save_directory, **kwargs):
        """Intentionally a no-op that overrides ``PreTrainedModel.save_pretrained``.

        The checkpoints are downloaded into ``config.checkpoint_dir`` by
        ``from_pretrained``, so nothing is serialized here.
        """
        pass

    @classmethod
    def from_pretrained(cls, pretrained_model_name_or_path, *model_args, **kwargs):
        """Download the Video2World checkpoints and build a model.

        ``pretrained_model_name_or_path`` and ``model_args`` are accepted for
        Hugging Face API compatibility but are not used.

        Args:
            config (keyword): ``DiffusionVideo2WorldConfig`` to start from. A
                default config is used when it is omitted or None.
            **kwargs: All other keyword arguments are applied to the config via
                ``config.update``.

        Returns:
            A new ``cls`` instance built from the updated config.
        """
        overrides = dict(kwargs)
        config = overrides.pop("config", None)
        if config is None:
            config = cls.config_class()
        config.update(overrides)
        # Only the 7B and 14B Video2World variants are distinguished here.
        model_sizes = ["7B"] if "7B" in config.diffusion_transformer_dir else ["14B"]
        model_types = ["Video2World"]
        download_diffusion(model_types, model_sizes, config.checkpoint_dir)
        return cls(config)
def demo(cfg):
    """Run video-to-world generation demo.

    This function handles the main video-to-world generation pipeline, including:
    - Setting up the random seed for reproducibility
    - Initializing the generation pipeline with the provided configuration
    - Processing single or multiple prompts/images/videos from input
    - Generating videos from prompts and images/videos
    - Saving the generated videos and corresponding prompts to disk

    Args:
        cfg (argparse.Namespace): Configuration namespace containing:
            - Model configuration (checkpoint paths, model settings)
            - Generation parameters (guidance, steps, dimensions)
            - Input/output settings (prompts/images/videos, save paths)
            - Performance options (model offloading settings)

    The function will save:
        - Generated MP4 video files
        - Text files containing the processed prompts

    In batch mode (``cfg.batch_input_path`` set), outputs are named ``{i}.mp4`` /
    ``{i}.txt`` by input index. Otherwise they are named
    ``{cfg.video_save_name}.mp4`` / ``.txt``.

    An input is skipped (with a critical log message where applicable) when:
        - it has no visual input;
        - it has no prompt while the prompt upsampler is disabled;
        - it fails ``check_input_frames``;
        - guardrails block its generation.
    The function then continues with the next prompt, if any.
    """
    misc.set_random_seed(cfg.seed)
    inference_type = "video2world"
    validate_args(cfg, inference_type)
    # Initialize video2world generation model pipeline
    pipeline = DiffusionVideo2WorldGenerationPipeline(
        inference_type=inference_type,
        checkpoint_dir=cfg.checkpoint_dir,
        checkpoint_name=cfg.diffusion_transformer_dir,
        prompt_upsampler_dir=cfg.prompt_upsampler_dir,
        enable_prompt_upsampler=not cfg.disable_prompt_upsampler,
        offload_network=cfg.offload_diffusion_transformer,
        offload_tokenizer=cfg.offload_tokenizer,
        offload_text_encoder_model=cfg.offload_text_encoder_model,
        offload_prompt_upsampler=cfg.offload_prompt_upsampler,
        offload_guardrail_models=cfg.offload_guardrail_models,
        guidance=cfg.guidance,
        num_steps=cfg.num_steps,
        height=cfg.height,
        width=cfg.width,
        fps=cfg.fps,
        num_video_frames=cfg.num_video_frames,
        seed=cfg.seed,
        num_input_frames=cfg.num_input_frames,
    )
    # Handle multiple prompts if prompt file is provided
    if cfg.batch_input_path:
        # Use the cfg parameter, not the script-level global ``args``, so
        # demo() also works when called from other code.
        log.info(f"Reading batch inputs from path: {cfg.batch_input_path}")
        prompts = read_prompts_from_file(cfg.batch_input_path)
    else:
        # Single prompt case
        prompts = [{"prompt": cfg.prompt, "visual_input": cfg.input_image_or_video_path}]
    os.makedirs(cfg.video_save_folder, exist_ok=True)
    for i, input_dict in enumerate(prompts):
        current_prompt = input_dict.get("prompt", None)
        if current_prompt is None and cfg.disable_prompt_upsampler:
            log.critical("Prompt is missing, skipping world generation.")
            continue
        current_image_or_video_path = input_dict.get("visual_input", None)
        if current_image_or_video_path is None:
            log.critical("Visual input is missing, skipping world generation.")
            continue
        # Check input frames
        if not check_input_frames(current_image_or_video_path, cfg.num_input_frames):
            continue
        # Generate video
        generated_output = pipeline.generate(
            prompt=current_prompt,
            image_or_video_path=current_image_or_video_path,
            negative_prompt=cfg.negative_prompt,
        )
        if generated_output is None:
            log.critical("Guardrail blocked video2world generation.")
            continue
        video, prompt = generated_output
        if cfg.batch_input_path:
            video_save_path = os.path.join(cfg.video_save_folder, f"{i}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{i}.txt")
        else:
            video_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.mp4")
            prompt_save_path = os.path.join(cfg.video_save_folder, f"{cfg.video_save_name}.txt")
        # Save video
        save_video(
            video=video,
            fps=cfg.fps,
            H=cfg.height,
            W=cfg.width,
            video_save_quality=5,
            video_save_path=video_save_path,
        )
        # Save prompt to text file alongside video
        with open(prompt_save_path, "wb") as f:
            f.write(prompt.encode("utf-8"))
        log.info(f"Saved video to {video_save_path}")
        log.info(f"Saved prompt to {prompt_save_path}")
def parse_arguments() -> argparse.Namespace:
    """Build the video2world demo command-line parser and parse ``sys.argv``.

    Shared options are registered by ``add_common_arguments``. The options added
    here are the video2world-specific fields that ``demo`` reads, with defaults
    matching ``DiffusionVideo2WorldConfig``:
    - ``diffusion_transformer_dir``
    - ``prompt_upsampler_dir``
    - ``input_image_or_video_path``
    - ``num_input_frames``
    """
    # With "resolve", these definitions replace any same-named option that
    # add_common_arguments may already register, instead of raising an
    # argparse conflict error.
    parser = argparse.ArgumentParser(
        description="Video to world generation demo script",
        conflict_handler="resolve",
    )
    add_common_arguments(parser)
    parser.add_argument(
        "--diffusion_transformer_dir",
        type=str,
        default="Cosmos-1.0-Diffusion-7B-Video2World",
        help="DiT model weights directory name relative to checkpoint_dir",
    )
    parser.add_argument(
        "--prompt_upsampler_dir",
        type=str,
        default="Pixtral-12B",
        help="Prompt upsampler weights directory relative to checkpoint_dir",
    )
    parser.add_argument(
        "--input_image_or_video_path",
        type=str,
        default=None,
        help="Input video/image path for generating a single video",
    )
    parser.add_argument(
        "--num_input_frames",
        type=int,
        default=1,
        help="Number of input frames for video2world prediction",
    )
    return parser.parse_args()


if __name__ == "__main__":
    args = parse_arguments()
    demo(args)