# Copyright 2025 The NVIDIA Team and The HuggingFace Team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
#     http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

import inspect
from typing import Callable, Dict, List, Optional, Union

import numpy as np
import torch
from transformers import T5EncoderModel, T5TokenizerFast

from ...callbacks import MultiPipelineCallbacks, PipelineCallback
from ...models import AutoencoderKLCosmos, CosmosTransformer3DModel
from ...schedulers import EDMEulerScheduler
from ...utils import is_cosmos_guardrail_available, is_torch_xla_available, logging, replace_example_docstring
from ...utils.torch_utils import randn_tensor
from ...video_processor import VideoProcessor
from ..pipeline_utils import DiffusionPipeline
from .pipeline_output import CosmosPipelineOutput

if is_cosmos_guardrail_available():
    from cosmos_guardrail import CosmosSafetyChecker
else:

    class CosmosSafetyChecker:
        def __init__(self, *args, **kwargs):
            raise ImportError(
                "`cosmos_guardrail` is not installed. Please install it to use the safety checker for Cosmos: "
                "`pip install cosmos_guardrail`."
            )


if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
| EXAMPLE_DOC_STRING = """ | |
| Examples: | |
| ```python | |
| >>> import torch | |
| >>> from diffusers import CosmosTextToWorldPipeline | |
| >>> from diffusers.utils import export_to_video | |
| >>> model_id = "nvidia/Cosmos-1.0-Diffusion-7B-Text2World" | |
| >>> pipe = CosmosTextToWorldPipeline.from_pretrained(model_id, torch_dtype=torch.bfloat16) | |
| >>> pipe.to("cuda") | |
| >>> prompt = "A sleek, humanoid robot stands in a vast warehouse filled with neatly stacked cardboard boxes on industrial shelves. The robot's metallic body gleams under the bright, even lighting, highlighting its futuristic design and intricate joints. A glowing blue light emanates from its chest, adding a touch of advanced technology. The background is dominated by rows of boxes, suggesting a highly organized storage system. The floor is lined with wooden pallets, enhancing the industrial setting. The camera remains static, capturing the robot's poised stance amidst the orderly environment, with a shallow depth of field that keeps the focus on the robot while subtly blurring the background for a cinematic effect." | |
| >>> output = pipe(prompt=prompt).frames[0] | |
| >>> export_to_video(output, "output.mp4", fps=30) | |
| ``` | |
| """ | |


# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Calls the scheduler's `set_timesteps` method and retrieves timesteps from the scheduler after the call. Handles
    custom timesteps. Any kwargs will be supplied to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")
    if timesteps is not None:
        accepts_timesteps = "timesteps" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accepts_timesteps:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        accept_sigmas = "sigmas" in set(inspect.signature(scheduler.set_timesteps).parameters.keys())
        if not accept_sigmas:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps


class CosmosTextToWorldPipeline(DiffusionPipeline):
    r"""
    Pipeline for text-to-world generation using [Cosmos Predict1](https://github.com/nvidia-cosmos/cosmos-predict1).

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        text_encoder ([`T5EncoderModel`]):
            Frozen text-encoder. Cosmos uses
            [T5](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5EncoderModel); specifically the
            [t5-11b](https://huggingface.co/google-t5/t5-11b) variant.
        tokenizer (`T5TokenizerFast`):
            Tokenizer of class
            [T5Tokenizer](https://huggingface.co/docs/transformers/model_doc/t5#transformers.T5Tokenizer).
        transformer ([`CosmosTransformer3DModel`]):
            Conditional Transformer to denoise the encoded video latents.
        scheduler ([`EDMEulerScheduler`]):
            A scheduler to be used in combination with `transformer` to denoise the encoded video latents.
        vae ([`AutoencoderKLCosmos`]):
            Variational Auto-Encoder (VAE) Model to encode and decode videos to and from latent representations.
    """

    model_cpu_offload_seq = "text_encoder->transformer->vae"
    _callback_tensor_inputs = ["latents", "prompt_embeds", "negative_prompt_embeds"]
    # We mark `safety_checker` as optional here to get around some test failures, but it is not really optional.
    _optional_components = ["safety_checker"]

    def __init__(
        self,
        text_encoder: T5EncoderModel,
        tokenizer: T5TokenizerFast,
        transformer: CosmosTransformer3DModel,
        vae: AutoencoderKLCosmos,
        scheduler: EDMEulerScheduler,
        safety_checker: CosmosSafetyChecker = None,
    ):
        super().__init__()

        if safety_checker is None:
            safety_checker = CosmosSafetyChecker()

        self.register_modules(
            vae=vae,
            text_encoder=text_encoder,
            tokenizer=tokenizer,
            transformer=transformer,
            scheduler=scheduler,
            safety_checker=safety_checker,
        )
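
        # The compression ratios come from the VAE config; the fallback of 8 is only used
        # when the pipeline is constructed without a `vae` component.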
        self.vae_scale_factor_temporal = (
            self.vae.config.temporal_compression_ratio if getattr(self, "vae", None) else 8
        )
        self.vae_scale_factor_spatial = self.vae.config.spatial_compression_ratio if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(vae_scale_factor=self.vae_scale_factor_spatial)

    def _get_t5_prompt_embeds(
        self,
        prompt: Union[str, List[str]] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
        device = device or self._execution_device
        dtype = dtype or self.text_encoder.dtype
        prompt = [prompt] if isinstance(prompt, str) else prompt

        text_inputs = self.tokenizer(
            prompt,
            padding="max_length",
            max_length=max_sequence_length,
            truncation=True,
            return_tensors="pt",
            return_length=True,
            return_offsets_mapping=False,
        )
        text_input_ids = text_inputs.input_ids
        prompt_attention_mask = text_inputs.attention_mask.bool().to(device)

        untruncated_ids = self.tokenizer(prompt, padding="longest", return_tensors="pt").input_ids
        if untruncated_ids.shape[-1] >= text_input_ids.shape[-1] and not torch.equal(text_input_ids, untruncated_ids):
            removed_text = self.tokenizer.batch_decode(untruncated_ids[:, max_sequence_length - 1 : -1])
            logger.warning(
                "The following part of your input was truncated because `max_sequence_length` is set to "
                f"{max_sequence_length} tokens: {removed_text}"
            )

        prompt_embeds = self.text_encoder(
            text_input_ids.to(device), attention_mask=prompt_attention_mask
        ).last_hidden_state
        prompt_embeds = prompt_embeds.to(dtype=dtype, device=device)
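
        # Zero out embeddings at padded positions so that padding tokens cannot leak into
        # cross-attention downstream.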
        lengths = prompt_attention_mask.sum(dim=1).cpu()
        for i, length in enumerate(lengths):
            prompt_embeds[i, length:] = 0

        return prompt_embeds

    def encode_prompt(
        self,
        prompt: Union[str, List[str]],
        negative_prompt: Optional[Union[str, List[str]]] = None,
        do_classifier_free_guidance: bool = True,
        num_videos_per_prompt: int = 1,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        max_sequence_length: int = 512,
        device: Optional[torch.device] = None,
        dtype: Optional[torch.dtype] = None,
    ):
| r""" | |
| Encodes the prompt into text encoder hidden states. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| prompt to be encoded | |
| negative_prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts not to guide the image generation. If not defined, one has to pass | |
| `negative_prompt_embeds` instead. Ignored when not using guidance (i.e., ignored if `guidance_scale` is | |
| less than `1`). | |
| do_classifier_free_guidance (`bool`, *optional*, defaults to `True`): | |
| Whether to use classifier free guidance or not. | |
| num_videos_per_prompt (`int`, *optional*, defaults to 1): | |
| Number of videos that should be generated per prompt. torch device to place the resulting embeddings on | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
| provided, text embeddings will be generated from `prompt` input argument. | |
| negative_prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated negative text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt | |
| weighting. If not provided, negative_prompt_embeds will be generated from `negative_prompt` input | |
| argument. | |
| device: (`torch.device`, *optional*): | |
| torch device | |
| dtype: (`torch.dtype`, *optional*): | |
| torch dtype | |
| """ | |
        device = device or self._execution_device

        prompt = [prompt] if isinstance(prompt, str) else prompt
        if prompt is not None:
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        if prompt_embeds is None:
            prompt_embeds = self._get_t5_prompt_embeds(
                prompt=prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
            )

        # duplicate text embeddings for each generation per prompt, using mps friendly method
        _, seq_len, _ = prompt_embeds.shape
        prompt_embeds = prompt_embeds.repeat(1, num_videos_per_prompt, 1)
        prompt_embeds = prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance and negative_prompt_embeds is None:
            negative_prompt = negative_prompt or ""
            negative_prompt = batch_size * [negative_prompt] if isinstance(negative_prompt, str) else negative_prompt

            if prompt is not None and type(prompt) is not type(negative_prompt):
                raise TypeError(
                    f"`negative_prompt` should be the same type as `prompt`, but got {type(negative_prompt)} !="
                    f" {type(prompt)}."
                )
            elif batch_size != len(negative_prompt):
                raise ValueError(
                    f"`negative_prompt`: {negative_prompt} has batch size {len(negative_prompt)}, but `prompt`:"
                    f" {prompt} has batch size {batch_size}. Please make sure that passed `negative_prompt` matches"
                    " the batch size of `prompt`."
                )

            negative_prompt_embeds = self._get_t5_prompt_embeds(
                prompt=negative_prompt, max_sequence_length=max_sequence_length, device=device, dtype=dtype
            )

            # duplicate text embeddings for each generation per prompt, using mps friendly method
            _, seq_len, _ = negative_prompt_embeds.shape
            negative_prompt_embeds = negative_prompt_embeds.repeat(1, num_videos_per_prompt, 1)
            negative_prompt_embeds = negative_prompt_embeds.view(batch_size * num_videos_per_prompt, seq_len, -1)

        return prompt_embeds, negative_prompt_embeds

    def prepare_latents(
        self,
        batch_size: int,
        num_channels_latents: int = 16,
        height: int = 704,
        width: int = 1280,
        num_frames: int = 121,
        dtype: Optional[torch.dtype] = None,
        device: Optional[torch.device] = None,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
    ) -> torch.Tensor:
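        # In the EDM formulation, denoising starts from pure noise at `sigma_max`, so both
        # user-provided and freshly sampled latents are scaled by the scheduler's `sigma_max`.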
        if latents is not None:
            return latents.to(device=device, dtype=dtype) * self.scheduler.config.sigma_max

        num_latent_frames = (num_frames - 1) // self.vae_scale_factor_temporal + 1
        latent_height = height // self.vae_scale_factor_spatial
        latent_width = width // self.vae_scale_factor_spatial
        shape = (batch_size, num_channels_latents, num_latent_frames, latent_height, latent_width)

        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        return latents * self.scheduler.config.sigma_max

    def check_inputs(
        self,
        prompt,
        height,
        width,
        prompt_embeds=None,
        callback_on_step_end_tensor_inputs=None,
    ):
        if height % 16 != 0 or width % 16 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 16 but are {height} and {width}.")

        if callback_on_step_end_tensor_inputs is not None and not all(
            k in self._callback_tensor_inputs for k in callback_on_step_end_tensor_inputs
        ):
            raise ValueError(
                f"`callback_on_step_end_tensor_inputs` has to be in {self._callback_tensor_inputs}, but found"
                f" {[k for k in callback_on_step_end_tensor_inputs if k not in self._callback_tensor_inputs]}"
            )

        if prompt is not None and prompt_embeds is not None:
            raise ValueError(
                f"Cannot forward both `prompt`: {prompt} and `prompt_embeds`: {prompt_embeds}. Please make sure to"
                " only forward one of the two."
            )
        elif prompt is None and prompt_embeds is None:
            raise ValueError(
                "Provide either `prompt` or `prompt_embeds`. Cannot leave both `prompt` and `prompt_embeds` undefined."
            )
        elif prompt is not None and (not isinstance(prompt, str) and not isinstance(prompt, list)):
            raise ValueError(f"`prompt` has to be of type `str` or `list` but is {type(prompt)}")

    @property
    def guidance_scale(self):
        return self._guidance_scale

    @property
    def do_classifier_free_guidance(self):
        return self._guidance_scale > 1.0

    @property
    def num_timesteps(self):
        return self._num_timesteps

    @property
    def current_timestep(self):
        return self._current_timestep

    @property
    def interrupt(self):
        return self._interrupt

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        prompt: Union[str, List[str]] = None,
        negative_prompt: Optional[Union[str, List[str]]] = None,
        height: int = 704,
        width: int = 1280,
        num_frames: int = 121,
        num_inference_steps: int = 36,
        guidance_scale: float = 7.0,
        fps: int = 30,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        prompt_embeds: Optional[torch.Tensor] = None,
        negative_prompt_embeds: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        return_dict: bool = True,
        callback_on_step_end: Optional[
            Union[Callable[[int, int, Dict], None], PipelineCallback, MultiPipelineCallbacks]
        ] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        max_sequence_length: int = 512,
    ):
| r""" | |
| The call function to the pipeline for generation. | |
| Args: | |
| prompt (`str` or `List[str]`, *optional*): | |
| The prompt or prompts to guide the image generation. If not defined, one has to pass `prompt_embeds`. | |
| instead. | |
| height (`int`, defaults to `720`): | |
| The height in pixels of the generated image. | |
| width (`int`, defaults to `1280`): | |
| The width in pixels of the generated image. | |
| num_frames (`int`, defaults to `121`): | |
| The number of frames in the generated video. | |
| num_inference_steps (`int`, defaults to `36`): | |
| The number of denoising steps. More denoising steps usually lead to a higher quality image at the | |
| expense of slower inference. | |
| guidance_scale (`float`, defaults to `7.0`): | |
| Guidance scale as defined in [Classifier-Free Diffusion | |
| Guidance](https://huggingface.co/papers/2207.12598). `guidance_scale` is defined as `w` of equation 2. | |
| of [Imagen Paper](https://huggingface.co/papers/2205.11487). Guidance scale is enabled by setting | |
| `guidance_scale > 1`. | |
| fps (`int`, defaults to `30`): | |
| The frames per second of the generated video. | |
| num_videos_per_prompt (`int`, *optional*, defaults to 1): | |
| The number of images to generate per prompt. | |
| generator (`torch.Generator` or `List[torch.Generator]`, *optional*): | |
| A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make | |
| generation deterministic. | |
| latents (`torch.Tensor`, *optional*): | |
| Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image | |
| generation. Can be used to tweak the same generation with different prompts. If not provided, a latents | |
| tensor is generated by sampling using the supplied random `generator`. | |
| prompt_embeds (`torch.Tensor`, *optional*): | |
| Pre-generated text embeddings. Can be used to easily tweak text inputs, *e.g.* prompt weighting. If not | |
| provided, text embeddings will be generated from `prompt` input argument. | |
| negative_prompt_embeds (`torch.FloatTensor`, *optional*): | |
| Pre-generated negative text embeddings. For PixArt-Sigma this negative prompt should be "". If not | |
| provided, negative_prompt_embeds will be generated from `negative_prompt` input argument. | |
| output_type (`str`, *optional*, defaults to `"pil"`): | |
| The output format of the generated image. Choose between `PIL.Image` or `np.array`. | |
| return_dict (`bool`, *optional*, defaults to `True`): | |
| Whether or not to return a [`CosmosPipelineOutput`] instead of a plain tuple. | |
| callback_on_step_end (`Callable`, `PipelineCallback`, `MultiPipelineCallbacks`, *optional*): | |
| A function or a subclass of `PipelineCallback` or `MultiPipelineCallbacks` that is called at the end of | |
| each denoising step during the inference. with the following arguments: `callback_on_step_end(self: | |
| DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`. `callback_kwargs` will include a | |
| list of all tensors as specified by `callback_on_step_end_tensor_inputs`. | |
| callback_on_step_end_tensor_inputs (`List`, *optional*): | |
| The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list | |
| will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the | |
| `._callback_tensor_inputs` attribute of your pipeline class. | |
| Examples: | |
| Returns: | |
| [`~CosmosPipelineOutput`] or `tuple`: | |
| If `return_dict` is `True`, [`CosmosPipelineOutput`] is returned, otherwise a `tuple` is returned where | |
| the first element is a list with the generated images and the second element is a list of `bool`s | |
| indicating whether the corresponding generated image contains "not-safe-for-work" (nsfw) content. | |
| """ | |
        if self.safety_checker is None:
            raise ValueError(
                f"You have disabled the safety checker for {self.__class__}. This is in violation of the "
                "[NVIDIA Open Model License Agreement](https://www.nvidia.com/en-us/agreements/enterprise-software/nvidia-open-model-license). "
                "Please ensure that you are compliant with the license agreement."
            )

        if isinstance(callback_on_step_end, (PipelineCallback, MultiPipelineCallbacks)):
            callback_on_step_end_tensor_inputs = callback_on_step_end.tensor_inputs

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(prompt, height, width, prompt_embeds, callback_on_step_end_tensor_inputs)

        self._guidance_scale = guidance_scale
        self._current_timestep = None
        self._interrupt = False

        device = self._execution_device
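
        # The guardrail runs on the execution device only for the duration of the text
        # check, then moves back to CPU so accelerator memory stays free for denoising.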
        if self.safety_checker is not None:
            self.safety_checker.to(device)
            if prompt is not None:
                prompt_list = [prompt] if isinstance(prompt, str) else prompt
                for p in prompt_list:
                    if not self.safety_checker.check_text_safety(p):
                        raise ValueError(
                            f"Cosmos Guardrail detected unsafe text in the prompt: {p}. Please ensure that the "
                            f"prompt abides by the NVIDIA Open Model License Agreement."
                        )
            self.safety_checker.to("cpu")

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # 3. Encode input prompt
        (
            prompt_embeds,
            negative_prompt_embeds,
        ) = self.encode_prompt(
            prompt=prompt,
            negative_prompt=negative_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
            num_videos_per_prompt=num_videos_per_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            device=device,
            max_sequence_length=max_sequence_length,
        )

        # 4. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device)

        # 5. Prepare latent variables
        transformer_dtype = self.transformer.dtype
        num_channels_latents = self.transformer.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_channels_latents,
            height,
            width,
            num_frames,
            torch.float32,
            device,
            generator,
            latents,
        )
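
        # The transformer takes a spatial padding mask; an all-zero mask marks no pixels
        # as padding at the requested resolution.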
        padding_mask = latents.new_zeros(1, 1, height, width, dtype=transformer_dtype)

        # 6. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                if self.interrupt:
                    continue

                self._current_timestep = t
                timestep = t.expand(latents.shape[0]).to(transformer_dtype)

                latent_model_input = latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                latent_model_input = latent_model_input.to(transformer_dtype)

                noise_pred = self.transformer(
                    hidden_states=latent_model_input,
                    timestep=timestep,
                    encoder_hidden_states=prompt_embeds,
                    fps=fps,
                    padding_mask=padding_mask,
                    return_dict=False,
                )[0]

                sample = latents
                if self.do_classifier_free_guidance:
                    noise_pred_uncond = self.transformer(
                        hidden_states=latent_model_input,
                        timestep=timestep,
                        encoder_hidden_states=negative_prompt_embeds,
                        fps=fps,
                        padding_mask=padding_mask,
                        return_dict=False,
                    )[0]
                    noise_pred = torch.cat([noise_pred_uncond, noise_pred])
                    sample = torch.cat([sample, sample])
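
                # `scheduler.step` is called twice per iteration: first to obtain the x0
                # prediction (index [1] of the return tuple), after which the internal step
                # index is rewound, and then again below to advance the sample.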
                # pred_original_sample (x0)
                noise_pred = self.scheduler.step(noise_pred, t, sample, return_dict=False)[1]
                self.scheduler._step_index -= 1

                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_cond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
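                    # i.e. x0 <- x0_cond + w * (x0_cond - x0_uncond): classifier-free
                    # guidance applied to the x0 predictions rather than the raw noise.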

                # pred_sample (eps)
                latents = self.scheduler.step(
                    noise_pred, t, latents, return_dict=False, pred_original_sample=noise_pred
                )[0]

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)
                    prompt_embeds = callback_outputs.pop("prompt_embeds", prompt_embeds)
                    negative_prompt_embeds = callback_outputs.pop("negative_prompt_embeds", negative_prompt_embeds)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        self._current_timestep = None

        if not output_type == "latent":
            if self.vae.config.latents_mean is not None:
                latents_mean, latents_std = self.vae.config.latents_mean, self.vae.config.latents_std
                latents_mean = (
                    torch.tensor(latents_mean)
                    .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
                    .to(latents)
                )
                latents_std = (
                    torch.tensor(latents_std)
                    .view(1, self.vae.config.latent_channels, -1, 1, 1)[:, :, : latents.size(2)]
                    .to(latents)
                )
                latents = latents * latents_std / self.scheduler.config.sigma_data + latents_mean
            else:
                latents = latents / self.scheduler.config.sigma_data
            video = self.vae.decode(latents.to(self.vae.dtype), return_dict=False)[0]

            if self.safety_checker is not None:
                self.safety_checker.to(device)
                video = self.video_processor.postprocess_video(video, output_type="np")
                video = (video * 255).astype(np.uint8)
                video_batch = []
                for vid in video:
                    vid = self.safety_checker.check_video_safety(vid)
                    video_batch.append(vid)
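                # `check_video_safety` may alter frames; convert the uint8 result back to
                # the [-1, 1] float range that `postprocess_video` expects.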
                video = np.stack(video_batch).astype(np.float32) / 255.0 * 2 - 1
                video = torch.from_numpy(video).permute(0, 4, 1, 2, 3)
                video = self.video_processor.postprocess_video(video, output_type=output_type)
                self.safety_checker.to("cpu")
            else:
                video = self.video_processor.postprocess_video(video, output_type=output_type)
        else:
            video = latents

        # Offload all models
        self.maybe_free_model_hooks()

        if not return_dict:
            return (video,)

        return CosmosPipelineOutput(frames=video)