Commit Β·
4cd55fa
1
Parent(s): 14bcb03
Upload 6 files (#1)
Browse files- Upload 6 files (a2c2e9a29781f1deb4c9f895dc89c3aa948fbbb3)
Co-authored-by: Shaswat Garg <ShaswatRobotics@users.noreply.huggingface.co>
ctrl_world/droid/checkpoint-10000.pt
ADDED
|
@@ -0,0 +1,3 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
version https://git-lfs.github.com/spec/v1
|
| 2 |
+
oid sha256:ed17de48180d4e6f89fd33c53e9fb7a0196189c1a67d44c2c486a279a80ea8a8
|
| 3 |
+
size 9281040326
|
ctrl_world/droid/config.json
ADDED
|
@@ -0,0 +1,52 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
{
|
| 2 |
+
"name": "droid_ctrl_world",
|
| 3 |
+
"env": "DROID",
|
| 4 |
+
"model_type": "ctrl_world",
|
| 5 |
+
"metadata": {
|
| 6 |
+
"num_history": 6,
|
| 7 |
+
"num_frames": 5,
|
| 8 |
+
"action_dim": 7
|
| 9 |
+
},
|
| 10 |
+
"util_folders":{
|
| 11 |
+
"models": "../src/models"
|
| 12 |
+
},
|
| 13 |
+
"models": [
|
| 14 |
+
{
|
| 15 |
+
"name": "world_model",
|
| 16 |
+
"framework": null,
|
| 17 |
+
"format": "state_dict",
|
| 18 |
+
"source": {
|
| 19 |
+
"weights_path": "checkpoint-10000.pt",
|
| 20 |
+
"class_path": "../src/world_model.py",
|
| 21 |
+
"class_name": "CrtlWorld",
|
| 22 |
+
"class_args": [
|
| 23 |
+
{
|
| 24 |
+
"svd_model_path": "stabilityai/stable-video-diffusion-img2vid",
|
| 25 |
+
"clip_model_path": "openai/clip-vit-base-patch32",
|
| 26 |
+
"num_history": 6,
|
| 27 |
+
"num_frames": 5,
|
| 28 |
+
"action_dim": 7,
|
| 29 |
+
"text_cond": true,
|
| 30 |
+
"motion_bucket_id": 127,
|
| 31 |
+
"fps": 7,
|
| 32 |
+
"guidance_scale": 1.0,
|
| 33 |
+
"num_inference_steps": 50,
|
| 34 |
+
"decode_chunk_size": 7,
|
| 35 |
+
"width": 320,
|
| 36 |
+
"height": 192
|
| 37 |
+
}]
|
| 38 |
+
},
|
| 39 |
+
"methods":
|
| 40 |
+
[
|
| 41 |
+
{
|
| 42 |
+
"name": "blocks_left_in_kv_cache",
|
| 43 |
+
"method_name": "blocks_left_in_kv_cache"
|
| 44 |
+
},
|
| 45 |
+
{
|
| 46 |
+
"name": "reset_kv_cache",
|
| 47 |
+
"method_name": "reset_kv_cache"
|
| 48 |
+
}
|
| 49 |
+
]
|
| 50 |
+
}
|
| 51 |
+
]
|
| 52 |
+
}
|
ctrl_world/src/models/pipeline_ctrl_world.py
ADDED
|
@@ -0,0 +1,823 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 2 |
+
import torch
|
| 3 |
+
from einops import rearrange, repeat
|
| 4 |
+
import PIL
|
| 5 |
+
import einops
|
| 6 |
+
|
| 7 |
+
# from diffusers import TextToVideoSDPipeline, StableVideoDiffusionPipeline
|
| 8 |
+
from diffusers import TextToVideoSDPipeline
|
| 9 |
+
from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
|
| 10 |
+
|
| 11 |
+
|
| 12 |
+
from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipelineOutput
|
| 13 |
+
from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput
|
| 14 |
+
from diffusers.utils.torch_utils import randn_tensor
|
| 15 |
+
|
| 16 |
+
def svd_tensor2vid(video: torch.Tensor, processor, output_type="np"):
    """Post-process a batched video tensor into per-sample frame sequences.

    Based on:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    Args:
        video: tensor of shape ``(batch, channels, frames, height, width)``.
        processor: image-processor-like object exposing ``postprocess(frames, output_type)``
            (e.g. a diffusers ``VaeImageProcessor`` — assumed; confirm at the call site).
        output_type: format forwarded to ``processor.postprocess`` (default ``"np"``).

    Returns:
        A list with one postprocessed entry per batch element.
    """
    # For each sample, move the frame axis first so the processor receives a
    # stack of (frames, channels, height, width) images.
    return [
        processor.postprocess(video[sample_idx].permute(1, 0, 2, 3), output_type)
        for sample_idx in range(video.shape[0])
    ]
|
| 29 |
+
|
| 30 |
+
class LatentToVideoPipeline(TextToVideoSDPipeline):
    # Text-to-video pipeline variant that runs the denoising loop on
    # caller-supplied latents and forwards extra conditioning signals
    # (condition_latent / mask / motion) to the UNet.

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        height=None,
        width=None,
        num_frames: int = 16,
        num_inference_steps: int = 50,
        guidance_scale=9.0,
        negative_prompt=None,
        eta: float = 0.0,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="np",
        return_dict: bool = True,
        callback=None,
        callback_steps: int = 1,
        cross_attention_kwargs=None,
        condition_latent=None,
        mask=None,
        timesteps=None,
        motion=None,
    ):
        r"""Denoise the given latents into a video, with optional extra conditioning.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide generation. If not defined, pass `prompt_embeds` instead.
            height (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The height in pixels of the generated video.
            width (`int`, *optional*, defaults to self.unet.config.sample_size * self.vae_scale_factor):
                The width in pixels of the generated video.
            num_frames (`int`, *optional*, defaults to 16):
                The number of video frames that are generated.
            num_inference_steps (`int`, *optional*, defaults to 50):
                The number of denoising steps. Ignored when `timesteps` is supplied explicitly.
            guidance_scale (`float`, *optional*, defaults to 9.0):
                Classifier-free guidance weight `w` ([Imagen](https://arxiv.org/pdf/2205.11487.pdf)
                eq. 2). Guidance is enabled when `guidance_scale > 1`.
            negative_prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts not to guide generation; ignored when guidance is off.
            eta (`float`, *optional*, defaults to 0.0):
                DDIM eta (https://arxiv.org/abs/2010.02502); only applies to
                [`schedulers.DDIMScheduler`], ignored for others.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                Generator(s) to make generation deterministic.
            latents (`torch.FloatTensor`):
                Noisy latents of shape `(batch_size, num_channel, num_frames, height, width)`.
                Effectively REQUIRED despite the `None` default: the execution device is
                taken from `latents.device` and no latents are sampled here.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated text embeddings; generated from `prompt` when not provided.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-generated negative text embeddings.
            output_type (`str`, *optional*, defaults to `"np"`):
                `"pt"` returns the decoded tensor directly; anything else goes through
                `tensor2vid` post-processing.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`]
                instead of a plain tuple.
            callback (`Callable`, *optional*):
                Called every `callback_steps` steps as `callback(step, timestep, latents)`.
            callback_steps (`int`, *optional*, defaults to 1):
                Frequency at which `callback` is called.
            cross_attention_kwargs (`dict`, *optional*):
                Kwargs passed to the `AttentionProcessor`; a `"scale"` entry is used as the
                text-encoder LoRA scale.
            condition_latent (`torch.Tensor`, *optional*):
                Conditioning latent forwarded to the UNet; duplicated along the batch
                dimension when classifier-free guidance is active.
            mask (*optional*):
                Conditioning mask forwarded to the UNet.
            timesteps (*optional*):
                Explicit scheduler timesteps; when given, overrides `num_inference_steps`.
            motion (*optional*):
                Motion conditioning values forwarded to the UNet.

        Returns:
            [`~pipelines.stable_diffusion.TextToVideoSDPipelineOutput`] or `tuple`:
                The output object when `return_dict` is True, otherwise `(video, latents)`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_images_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        #device = self._execution_device
        # Device is inferred from the caller-supplied latents (see docstring note).
        device = latents.device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        if timesteps is None:
            timesteps = self.scheduler.timesteps
        else:
            num_inference_steps = len(timesteps)
        # 5. Prepare latent variables. do nothing

        # 6. Prepare extra step kwargs. TODO: Logic should ideally just be moved out of the pipeline
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        # For CFG the same condition latent serves as the "unconditional" half,
        # so it is concatenated with itself along the batch dimension.
        uncondition_latent = condition_latent
        condition_latent = torch.cat([uncondition_latent, condition_latent]) if do_classifier_free_guidance else condition_latent
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                if motion is not None:
                    # NOTE(review): after the first iteration `motion` is already a
                    # tensor, so `torch.tensor(motion, ...)` re-copies it each step
                    # (and warns on recent torch) — presumably harmless; confirm.
                    motion = torch.tensor(motion, device=device)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    condition_latent=condition_latent,
                    mask=mask,
                    motion=motion
                ).sample
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # reshape latents
                # NOTE(review): this shadows the `width`/`height` arguments with the
                # last two latent dims; the (width, height) naming order here follows
                # the original author — confirm orientation if it ever matters.
                bsz, channel, frames, width, height = latents.shape
                # Fold frames into the batch dim because the scheduler steps over
                # 4-D (image-shaped) tensors.
                latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
                noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # reshape latents back
                latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
            video = video_tensor
        else:
            # NOTE(review): `tensor2vid` is not defined or imported in the portion
            # of this module visible here (only `svd_tensor2vid` is, and it has a
            # different signature). Presumably it is defined later in this file —
            # confirm, otherwise this branch raises NameError.
            video = tensor2vid(video_tensor)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (video, latents)

        return TextToVideoSDPipelineOutput(frames=video)
|
| 233 |
+
|
| 234 |
+
def _append_dims(x, target_dims):
|
| 235 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
| 236 |
+
dims_to_append = target_dims - x.ndim
|
| 237 |
+
if dims_to_append < 0:
|
| 238 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
| 239 |
+
return x[(...,) + (None,) * dims_to_append]
|
| 240 |
+
|
| 241 |
+
class CtrlWorldDiffusionPipeline(StableVideoDiffusionPipeline):
|
| 242 |
+
@torch.no_grad()
|
| 243 |
+
def __call__(
|
| 244 |
+
self,
|
| 245 |
+
image,
|
| 246 |
+
text,
|
| 247 |
+
height: int = 576,
|
| 248 |
+
width: int = 1024,
|
| 249 |
+
num_frames: Optional[int] = None,
|
| 250 |
+
num_inference_steps: int = 25,
|
| 251 |
+
min_guidance_scale: float = 1.0,
|
| 252 |
+
max_guidance_scale: float = 3.0,
|
| 253 |
+
fps: int = 7,
|
| 254 |
+
motion_bucket_id: int = 127,
|
| 255 |
+
noise_aug_strength: int = 0.02,
|
| 256 |
+
decode_chunk_size: Optional[int] = None,
|
| 257 |
+
num_videos_per_prompt: Optional[int] = 1,
|
| 258 |
+
generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
|
| 259 |
+
latents: Optional[torch.FloatTensor] = None,
|
| 260 |
+
output_type: Optional[str] = "pil",
|
| 261 |
+
callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
|
| 262 |
+
callback_on_step_end_tensor_inputs: List[str] = ["latents"],
|
| 263 |
+
return_dict: bool = True,
|
| 264 |
+
mask = None,
|
| 265 |
+
cond_wrist=None,
|
| 266 |
+
history=None,
|
| 267 |
+
frame_level_cond=False,
|
| 268 |
+
his_cond_zero=False,
|
| 269 |
+
):
|
| 270 |
+
r"""
|
| 271 |
+
The call function to the pipeline for generation.
|
| 272 |
+
|
| 273 |
+
Args:
|
| 274 |
+
image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.FloatTensor`):
|
| 275 |
+
Image or images to guide image generation. If you provide a tensor, it needs to be compatible with
|
| 276 |
+
[`CLIPImageProcessor`](https://huggingface.co/lambdalabs/sd-image-variations-diffusers/blob/main/feature_extractor/preprocessor_config.json).
|
| 277 |
+
height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 278 |
+
The height in pixels of the generated image.
|
| 279 |
+
width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
|
| 280 |
+
The width in pixels of the generated image.
|
| 281 |
+
num_frames (`int`, *optional*):
|
| 282 |
+
The number of video frames to generate. Defaults to 14 for `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`
|
| 283 |
+
num_inference_steps (`int`, *optional*, defaults to 25):
|
| 284 |
+
The number of denoising steps. More denoising steps usually lead to a higher quality image at the
|
| 285 |
+
expense of slower inference. This parameter is modulated by `strength`.
|
| 286 |
+
min_guidance_scale (`float`, *optional*, defaults to 1.0):
|
| 287 |
+
The minimum guidance scale. Used for the classifier free guidance with first frame.
|
| 288 |
+
max_guidance_scale (`float`, *optional*, defaults to 3.0):
|
| 289 |
+
The maximum guidance scale. Used for the classifier free guidance with last frame.
|
| 290 |
+
fps (`int`, *optional*, defaults to 7):
|
| 291 |
+
Frames per second. The rate at which the generated images shall be exported to a video after generation.
|
| 292 |
+
Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
|
| 293 |
+
motion_bucket_id (`int`, *optional*, defaults to 127):
|
| 294 |
+
The motion bucket ID. Used as conditioning for the generation. The higher the number the more motion will be in the video.
|
| 295 |
+
noise_aug_strength (`int`, *optional*, defaults to 0.02):
|
| 296 |
+
The amount of noise added to the init image, the higher it is the less the video will look like the init image. Increase it for more motion.
|
| 297 |
+
decode_chunk_size (`int`, *optional*):
|
| 298 |
+
The number of frames to decode at a time. The higher the chunk size, the higher the temporal consistency
|
| 299 |
+
between frames, but also the higher the memory consumption. By default, the decoder will decode all frames at once
|
| 300 |
+
for maximal quality. Reduce `decode_chunk_size` to reduce memory usage.
|
| 301 |
+
num_videos_per_prompt (`int`, *optional*, defaults to 1):
|
| 302 |
+
The number of images to generate per prompt.
|
| 303 |
+
generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
|
| 304 |
+
A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
|
| 305 |
+
generation deterministic.
|
| 306 |
+
latents (`torch.FloatTensor`, *optional*):
|
| 307 |
+
Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for image
|
| 308 |
+
generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
|
| 309 |
+
tensor is generated by sampling using the supplied random `generator`.
|
| 310 |
+
output_type (`str`, *optional*, defaults to `"pil"`):
|
| 311 |
+
The output format of the generated image. Choose between `PIL.Image` or `np.array`.
|
| 312 |
+
callback_on_step_end (`Callable`, *optional*):
|
| 313 |
+
A function that calls at the end of each denoising steps during the inference. The function is called
|
| 314 |
+
with the following arguments: `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int,
|
| 315 |
+
callback_kwargs: Dict)`. `callback_kwargs` will include a list of all tensors as specified by
|
| 316 |
+
`callback_on_step_end_tensor_inputs`.
|
| 317 |
+
callback_on_step_end_tensor_inputs (`List`, *optional*):
|
| 318 |
+
The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
|
| 319 |
+
will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
|
| 320 |
+
`._callback_tensor_inputs` attribute of your pipeline class.
|
| 321 |
+
return_dict (`bool`, *optional*, defaults to `True`):
|
| 322 |
+
Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
|
| 323 |
+
plain tuple.
|
| 324 |
+
|
| 325 |
+
Returns:
|
| 326 |
+
[`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
|
| 327 |
+
If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is returned,
|
| 328 |
+
otherwise a `tuple` is returned where the first element is a list of list with the generated frames.
|
| 329 |
+
|
| 330 |
+
Examples:
|
| 331 |
+
|
| 332 |
+
```py
|
| 333 |
+
from diffusers import StableVideoDiffusionPipeline
|
| 334 |
+
from diffusers.utils import load_image, export_to_video
|
| 335 |
+
|
| 336 |
+
pipe = StableVideoDiffusionPipeline.from_pretrained("stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16")
|
| 337 |
+
pipe.to("cuda")
|
| 338 |
+
|
| 339 |
+
image = load_image("https://lh3.googleusercontent.com/y-iFOHfLTwkuQSUegpwDdgKmOjRSTvPxat63dQLB25xkTs4lhIbRUFeNBWZzYf370g=s1200")
|
| 340 |
+
image = image.resize((1024, 576))
|
| 341 |
+
|
| 342 |
+
frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
|
| 343 |
+
export_to_video(frames, "generated.mp4", fps=7)
|
| 344 |
+
```
|
| 345 |
+
"""
|
| 346 |
+
# 0. Default height and width to unet
|
| 347 |
+
height = height or self.unet.config.sample_size * self.vae_scale_factor
|
| 348 |
+
width = width or self.unet.config.sample_size * self.vae_scale_factor
|
| 349 |
+
|
| 350 |
+
num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
|
| 351 |
+
decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
|
| 352 |
+
# device = self._execution_device
|
| 353 |
+
device = self.unet.device
|
| 354 |
+
# here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
|
| 355 |
+
# of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf . `guidance_scale = 1`
|
| 356 |
+
# corresponds to doing no classifier free guidance.
|
| 357 |
+
do_classifier_free_guidance = max_guidance_scale > 1.0
|
| 358 |
+
|
| 359 |
+
# # 1. Check inputs. Raise error if not correct
|
| 360 |
+
# self.check_inputs(image, height, width)
|
| 361 |
+
|
| 362 |
+
# # 2. Define call parameters
|
| 363 |
+
# if isinstance(image, PIL.Image.Image):
|
| 364 |
+
# batch_size = 1
|
| 365 |
+
# elif isinstance(image, list):
|
| 366 |
+
# batch_size = len(image)
|
| 367 |
+
# else:
|
| 368 |
+
# batch_size = image.shape[0]
|
| 369 |
+
# # 3. Encode input image
|
| 370 |
+
# # clip_imgae = self.video_processor.preprocess(image, height=224, width=224)
|
| 371 |
+
# clip_image = _resize_with_antialiasing(image, (224, 224))
|
| 372 |
+
# image_embeddings = self._encode_image(clip_image, device, num_videos_per_prompt, do_classifier_free_guidance)
|
| 373 |
+
image_embeddings = text
|
| 374 |
+
batch_size = image_embeddings.shape[0]
|
| 375 |
+
if do_classifier_free_guidance:
|
| 376 |
+
negative_image_embeddings = torch.zeros_like(image_embeddings)
|
| 377 |
+
image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])
|
| 378 |
+
|
| 379 |
+
|
| 380 |
+
# NOTE: Stable Diffusion Video was conditioned on fps - 1, which
|
| 381 |
+
# is why it is reduced here.
|
| 382 |
+
# See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
|
| 383 |
+
# fps = fps - 1 # we only use fps = 7 in train, so just set to 7
|
| 384 |
+
|
| 385 |
+
# 4. Encode input image using VAE
|
| 386 |
+
needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
|
| 387 |
+
if image.shape[-3] == 3: # (batch, 3, 256, 256)
|
| 388 |
+
image = self.video_processor.preprocess(image, height=height, width=width)
|
| 389 |
+
noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
|
| 390 |
+
# image = image + noise_aug_strength * noise
|
| 391 |
+
|
| 392 |
+
if needs_upcasting:
|
| 393 |
+
self.vae.to(dtype=torch.float32)
|
| 394 |
+
|
| 395 |
+
image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
|
| 396 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 397 |
+
|
| 398 |
+
# cast back to fp16 if needed
|
| 399 |
+
if needs_upcasting:
|
| 400 |
+
self.vae.to(dtype=torch.float16)
|
| 401 |
+
else: # (batch, 4, 32, 32)
|
| 402 |
+
image_latents = image/self.vae.config.scaling_factor
|
| 403 |
+
if do_classifier_free_guidance:
|
| 404 |
+
# negative_image_latent = torch.zeros_like(image_latents)
|
| 405 |
+
# image_latents = torch.cat([negative_image_latent, image_latents])
|
| 406 |
+
image_latents = torch.cat([image_latents]*2)
|
| 407 |
+
image_latents = image_latents.to(image_embeddings.dtype)
|
| 408 |
+
|
| 409 |
+
# Repeat the image latents for each frame so we can concatenate them with the noise
|
| 410 |
+
# image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
|
| 411 |
+
if history is not None:
|
| 412 |
+
B, num_his, C, H, W = history.shape
|
| 413 |
+
num_frames_all = num_frames + num_his
|
| 414 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames_all, 1, 1, 1)
|
| 415 |
+
if his_cond_zero:
|
| 416 |
+
image_latents[:,:num_his] = 0.0 # set history to 0
|
| 417 |
+
else:
|
| 418 |
+
image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
|
| 419 |
+
# mask = repeat(mask, '1 h w -> 2 f 1 h w', f=num_frames)
|
| 420 |
+
# 5. Get Added Time IDs
|
| 421 |
+
added_time_ids = self._get_add_time_ids(
|
| 422 |
+
fps,
|
| 423 |
+
motion_bucket_id,
|
| 424 |
+
noise_aug_strength,
|
| 425 |
+
image_embeddings.dtype,
|
| 426 |
+
batch_size,
|
| 427 |
+
num_videos_per_prompt,
|
| 428 |
+
do_classifier_free_guidance,
|
| 429 |
+
)
|
| 430 |
+
added_time_ids = added_time_ids.to(device)
|
| 431 |
+
|
| 432 |
+
# 4. Prepare timesteps
|
| 433 |
+
self.scheduler.set_timesteps(num_inference_steps, device=device)
|
| 434 |
+
timesteps = self.scheduler.timesteps
|
| 435 |
+
|
| 436 |
+
# 5. Prepare latent variables
|
| 437 |
+
num_channels_latents = self.unet.config.in_channels
|
| 438 |
+
latents = self.prepare_latents(
|
| 439 |
+
batch_size * num_videos_per_prompt,
|
| 440 |
+
num_frames,
|
| 441 |
+
num_channels_latents,
|
| 442 |
+
height,
|
| 443 |
+
width,
|
| 444 |
+
image_embeddings.dtype,
|
| 445 |
+
device,
|
| 446 |
+
generator,
|
| 447 |
+
latents,
|
| 448 |
+
)
|
| 449 |
+
|
| 450 |
+
# 7. Prepare guidance scale
|
| 451 |
+
guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
|
| 452 |
+
guidance_scale = guidance_scale.to(device, latents.dtype)
|
| 453 |
+
guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
|
| 454 |
+
guidance_scale = _append_dims(guidance_scale, latents.ndim)
|
| 455 |
+
|
| 456 |
+
self._guidance_scale = guidance_scale
|
| 457 |
+
|
| 458 |
+
# 8. Denoising loop
|
| 459 |
+
num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
|
| 460 |
+
self._num_timesteps = len(timesteps)
|
| 461 |
+
# print("prediction type",self.scheduler.config.prediction_type)
|
| 462 |
+
if cond_wrist is not None:
|
| 463 |
+
B,F, C, H, W = latents.shape
|
| 464 |
+
cond_wrist = einops.repeat(cond_wrist, 'b l c h w -> b (f l) (n c) h w', n=3,f=num_frames) # (B, 8, 12 , 24, 40)
|
| 465 |
+
cond_wrist = torch.cat([cond_wrist]*2) if do_classifier_free_guidance else cond_wrist
|
| 466 |
+
|
| 467 |
+
if history is not None:
|
| 468 |
+
history = torch.cat([history] * 2) if do_classifier_free_guidance else history
|
| 469 |
+
|
| 470 |
+
with self.progress_bar(total=num_inference_steps) as progress_bar:
|
| 471 |
+
for i, t in enumerate(timesteps):
|
| 472 |
+
# expand the latents if we are doing classifier free guidance
|
| 473 |
+
latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
|
| 474 |
+
latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
|
| 475 |
+
|
| 476 |
+
if history is not None:
|
| 477 |
+
latent_model_input = torch.cat([history, latent_model_input], dim=1) # (bsz*2,frame+F,4,32,32)
|
| 478 |
+
|
| 479 |
+
# Concatenate image_latents over channels dimention
|
| 480 |
+
latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
|
| 481 |
+
|
| 482 |
+
if cond_wrist is not None and i==0:
|
| 483 |
+
# print('cond_wrist_shape:',cond_wrist.shape, 'latent_model_input_shape:',latent_model_input.shape)
|
| 484 |
+
latent_model_input = torch.cat([latent_model_input, cond_wrist], dim=3) # (B, 8, 12, 96, 40)
|
| 485 |
+
|
| 486 |
+
|
| 487 |
+
# predict the noise residual
|
| 488 |
+
latent_model_input = latent_model_input.to(self.unet.dtype)
|
| 489 |
+
image_embeddings = image_embeddings.to(self.unet.dtype)
|
| 490 |
+
# print('extract_layer_idx:',extract_layer_idx)
|
| 491 |
+
# print('latent_model_input_shape:',latent_model_input.shape)
|
| 492 |
+
# print('encoder_hidden_states:',image_embeddings.shape)
|
| 493 |
+
# print('added_time_ids:',added_time_ids.shape)
|
| 494 |
+
noise_pred = self.unet(
|
| 495 |
+
latent_model_input,
|
| 496 |
+
t,
|
| 497 |
+
encoder_hidden_states=image_embeddings,
|
| 498 |
+
added_time_ids=added_time_ids,
|
| 499 |
+
return_dict=False,
|
| 500 |
+
frame_level_cond=frame_level_cond,
|
| 501 |
+
)[0]
|
| 502 |
+
|
| 503 |
+
if cond_wrist is not None:
|
| 504 |
+
noise_pred = noise_pred[:, :,:,:H, :W] # remove cond_wrist
|
| 505 |
+
if history is not None:
|
| 506 |
+
# print('history_shape:',history.shape)
|
| 507 |
+
# print('noise_pred_shape:',noise_pred.shape)
|
| 508 |
+
noise_pred = noise_pred[:, num_his:, :, :, :] # remove history
|
| 509 |
+
|
| 510 |
+
# perform guidance
|
| 511 |
+
if do_classifier_free_guidance:
|
| 512 |
+
noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
|
| 513 |
+
noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
|
| 514 |
+
|
| 515 |
+
# model_output = noise_pred
|
| 516 |
+
# # sigma = self.scheduler.get_sigma(t)
|
| 517 |
+
# # sigma = self.scheduler.sigmas[t]
|
| 518 |
+
# self.scheduler._init_step_index(t)
|
| 519 |
+
# sigma = self.scheduler.sigmas[self.scheduler.step_index]
|
| 520 |
+
# print("sigma", sigma)
|
| 521 |
+
# print(t)
|
| 522 |
+
# pred_original_sample = model_output * (-sigma / (sigma**2 + 1) ** 0.5) + (latents / (sigma**2 + 1))
|
| 523 |
+
# print(pred_original_sample.shape)
|
| 524 |
+
# latents = pred_original_sample
|
| 525 |
+
# # return pred_original_sample
|
| 526 |
+
# break
|
| 527 |
+
|
| 528 |
+
# compute the previous noisy sample x_t -> x_t-1
|
| 529 |
+
latents = self.scheduler.step(noise_pred, t, latents).prev_sample
|
| 530 |
+
|
| 531 |
+
if callback_on_step_end is not None:
|
| 532 |
+
callback_kwargs = {}
|
| 533 |
+
for k in callback_on_step_end_tensor_inputs:
|
| 534 |
+
callback_kwargs[k] = locals()[k]
|
| 535 |
+
callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
|
| 536 |
+
|
| 537 |
+
latents = callback_outputs.pop("latents", latents)
|
| 538 |
+
|
| 539 |
+
if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
|
| 540 |
+
progress_bar.update()
|
| 541 |
+
|
| 542 |
+
if not output_type == "latent":
|
| 543 |
+
# cast back to fp16 if needed
|
| 544 |
+
if needs_upcasting:
|
| 545 |
+
self.vae.to(dtype=torch.float16)
|
| 546 |
+
# latents = latents/self.vae.config.scaling_factor
|
| 547 |
+
latents = latents.to(self.vae.dtype)
|
| 548 |
+
frames = self.decode_latents(latents, num_frames, decode_chunk_size)
|
| 549 |
+
frames = svd_tensor2vid(frames, self.video_processor, output_type=output_type)
|
| 550 |
+
else:
|
| 551 |
+
frames = latents
|
| 552 |
+
|
| 553 |
+
self.maybe_free_model_hooks()
|
| 554 |
+
|
| 555 |
+
if not return_dict:
|
| 556 |
+
return frames,latents
|
| 557 |
+
|
| 558 |
+
return StableVideoDiffusionPipelineOutput(frames=frames)
|
| 559 |
+
|
| 560 |
+
class TextStableVideoDiffusionPipeline(StableVideoDiffusionPipeline):
    """Stable Video Diffusion pipeline whose UNet cross-attention context can be
    text embeddings (``condition_type="text"``), CLIP image embeddings
    (``condition_type="image"``, the default), or both concatenated (any other
    ``condition_type`` value). The per-frame VAE conditioning latent may also be
    supplied directly via ``condition_latent`` to skip the VAE encode.
    """

    @torch.no_grad()
    def __call__(
        self,
        image,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: Optional[List[str]] = None,
        return_dict: bool = True,
        mask=None,
        condition_type="image",
        condition_latent=None,
    ):
        r"""Run the denoising loop and generate a video.

        Args:
            image: Conditioning image (`PIL.Image.Image`, list of images, or a
                tensor). It is VAE-encoded into the per-frame conditioning
                latent unless ``condition_latent`` is provided, and CLIP-encoded
                unless ``condition_type == "text"``.
            prompt_embeds: Pre-computed text embeddings used as (part of) the
                UNet cross-attention context when ``condition_type`` is not
                ``"image"``.
            negative_prompt_embeds: Unconditional text embeddings; required
                together with ``prompt_embeds`` when classifier-free guidance
                is active (``max_guidance_scale > 1``).
            height: Output height in pixels.
            width: Output width in pixels.
            num_frames: Frames to generate; defaults to
                ``self.unet.config.num_frames``.
            num_inference_steps: Number of denoising steps.
            min_guidance_scale: Guidance weight applied to the first frame.
            max_guidance_scale: Guidance weight applied to the last frame; the
                weight is interpolated linearly in between. Values > 1 enable
                classifier-free guidance.
            fps: Frame-rate micro-conditioning (the UNet was trained on
                ``fps - 1``, so it is decremented before use).
            motion_bucket_id: Motion micro-conditioning; higher values yield
                more motion.
            noise_aug_strength: Amount of noise added to the conditioning
                image before VAE encoding.
            decode_chunk_size: Frames decoded per VAE call; defaults to
                ``num_frames``.
            num_videos_per_prompt: Videos generated per conditioning input.
            generator: Optional RNG(s) for deterministic sampling.
            latents: Optional pre-generated initial noise latents.
            output_type: ``"pil"``, ``"np"``, ``"pt"``, or ``"latent"``.
            callback_on_step_end: Optional callback invoked after every
                denoising step with the tensors named in
                ``callback_on_step_end_tensor_inputs``.
            callback_on_step_end_tensor_inputs: Tensor names forwarded to the
                callback; defaults to ``["latents"]``.
            return_dict: When ``False`` the frames are returned directly
                instead of wrapped in a
                ``StableVideoDiffusionPipelineOutput``.
            mask: Motion mask, concatenated on the channel axis when the UNet
                expects 9 input channels.
            condition_type: ``"image"``, ``"text"``, or any other value for
                combined image+text conditioning.
            condition_latent: Optional pre-computed per-frame conditioning
                latent, presumably shaped ``(batch, num_frames, C, H, W)`` —
                TODO confirm against callers.

        Returns:
            ``StableVideoDiffusionPipelineOutput`` with the generated frames,
            or the frames alone when ``return_dict`` is ``False``.
        """
        # Avoid a shared mutable default argument; the historical default is ["latents"].
        if callback_on_step_end_tensor_inputs is None:
            callback_on_step_end_tensor_inputs = ["latents"]

        # 0. Default height and width to the UNet configuration.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct.
        self.check_inputs(image, height, width)

        # 2. Define call parameters.
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # `guidance_scale` is defined analogously to the guidance weight `w` of
        # eq. (2) of the Imagen paper: https://arxiv.org/pdf/2205.11487.pdf
        # `guidance_scale = 1` corresponds to no classifier-free guidance.
        do_classifier_free_guidance = max_guidance_scale > 1.0

        # 3. Build the cross-attention conditioning (image and/or text embeddings).
        if condition_type == "image":
            image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
        elif condition_type == "text":
            if do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            image_embeddings = prompt_embeds
        else:
            # Image + text: concatenate along the sequence dimension.
            image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
            if do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            image_embeddings = torch.cat([image_embeddings, prompt_embeds], dim=1)

        # A 9-channel UNet expects an extra motion-mask channel block.
        motion_mask = self.unet.config.in_channels == 9
        if do_classifier_free_guidance and mask is not None:
            # Duplicate the mask for the (uncond, cond) batch halves.
            # Bug fix: previously this ran even when mask was None (the
            # default) and crashed inside torch.cat.
            mask = torch.cat([mask] * 2)

        # NOTE: Stable Video Diffusion was conditioned on fps - 1 during
        # training, which is why it is reduced here. See:
        # https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode the (noise-augmented) input image with the VAE.
        image = self.video_processor.preprocess(image, height=height, width=width)
        noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
        image = image + noise_aug_strength * noise

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        if condition_latent is None:
            image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
            image_latents = image_latents.to(image_embeddings.dtype)

            # Repeat the image latents for each frame so they can be
            # concatenated with the noise:
            # [batch, C, H, W] -> [batch, num_frames, C, H, W]
            condition_latent = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        else:
            if do_classifier_free_guidance:
                condition_latent = torch.cat([condition_latent] * 2)

        # Cast back to fp16 if needed.
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Get added time IDs (fps / motion-bucket / noise-aug micro-conditioning).
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 7. Prepare the initial latent noise.
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 8. Prepare the per-frame guidance scale (linear from first to last frame).
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 9. Denoising loop.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier-free guidance.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate the conditioning latent (and the mask, for
                # 9-channel UNets) over the channels dimension.
                if motion_mask:
                    latent_model_input = torch.cat([mask, latent_model_input, condition_latent], dim=2)
                else:
                    latent_model_input = torch.cat([latent_model_input, condition_latent], dim=2)

                # Predict the noise residual.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # Perform guidance.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # Cast back to fp16 if needed.
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = svd_tensor2vid(frames, self.video_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
|
ctrl_world/src/models/pipeline_stable_video_diffusion.py
ADDED
|
@@ -0,0 +1,742 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
# Copyright 2025 The HuggingFace Team. All rights reserved.
|
| 2 |
+
#
|
| 3 |
+
# Licensed under the Apache License, Version 2.0 (the "License");
|
| 4 |
+
# you may not use this file except in compliance with the License.
|
| 5 |
+
# You may obtain a copy of the License at
|
| 6 |
+
#
|
| 7 |
+
# http://www.apache.org/licenses/LICENSE-2.0
|
| 8 |
+
#
|
| 9 |
+
# Unless required by applicable law or agreed to in writing, software
|
| 10 |
+
# distributed under the License is distributed on an "AS IS" BASIS,
|
| 11 |
+
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
|
| 12 |
+
# See the License for the specific language governing permissions and
|
| 13 |
+
# limitations under the License.
|
| 14 |
+
|
| 15 |
+
import inspect
|
| 16 |
+
from dataclasses import dataclass
|
| 17 |
+
from typing import Callable, Dict, List, Optional, Union
|
| 18 |
+
|
| 19 |
+
import numpy as np
|
| 20 |
+
import PIL.Image
|
| 21 |
+
import torch
|
| 22 |
+
from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
|
| 23 |
+
|
| 24 |
+
from diffusers.image_processor import PipelineImageInput
|
| 25 |
+
|
| 26 |
+
# import from our own models instead of diffusers
|
| 27 |
+
# from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
|
| 28 |
+
from diffusers.models import AutoencoderKLTemporalDecoder
|
| 29 |
+
from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
| 30 |
+
|
| 31 |
+
from diffusers.schedulers import EulerDiscreteScheduler
|
| 32 |
+
from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
|
| 33 |
+
from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
|
| 34 |
+
from diffusers.video_processor import VideoProcessor
|
| 35 |
+
from diffusers.pipelines.pipeline_utils import DiffusionPipeline
|
| 36 |
+
|
| 37 |
+
|
| 38 |
+
# Optional TPU support: when torch_xla is installed, `xm.mark_step()` is invoked
# inside the denoising loop to flush the XLA graph after each step.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False
|
| 44 |
+
|
| 45 |
+
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name


# Usage example spliced into `__call__`'s docstring via `@replace_example_docstring`.
EXAMPLE_DOC_STRING = """
    Examples:
        ```py
        >>> from diffusers import StableVideoDiffusionPipeline
        >>> from diffusers.utils import load_image, export_to_video

        >>> pipe = StableVideoDiffusionPipeline.from_pretrained(
        ...     "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
        ... )
        >>> pipe.to("cuda")

        >>> image = load_image(
        ...     "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
        ... )
        >>> image = image.resize((1024, 576))

        >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
        >>> export_to_video(frames, "generated.mp4", fps=7)
        ```
"""
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
def _append_dims(x, target_dims):
|
| 71 |
+
"""Appends dimensions to the end of a tensor until it has target_dims dimensions."""
|
| 72 |
+
dims_to_append = target_dims - x.ndim
|
| 73 |
+
if dims_to_append < 0:
|
| 74 |
+
raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
|
| 75 |
+
return x[(...,) + (None,) * dims_to_append]
|
| 76 |
+
|
| 77 |
+
|
| 78 |
+
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure the scheduler's timestep schedule and return it.

    Exactly one of `timesteps` or `sigmas` may be supplied to override the scheduler's
    default spacing strategy; otherwise `num_inference_steps` drives the schedule.
    Extra keyword arguments are forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            The number of diffusion steps used when generating samples with a pre-trained model. If used, `timesteps`
            must be `None`.
        device (`str` or `torch.device`, *optional*):
            The device to which the timesteps should be moved to. If `None`, the timesteps are not moved.
        timesteps (`List[int]`, *optional*):
            Custom timesteps used to override the timestep spacing strategy of the scheduler. If `timesteps` is passed,
            `num_inference_steps` and `sigmas` must be `None`.
        sigmas (`List[float]`, *optional*):
            Custom sigmas used to override the timestep spacing strategy of the scheduler. If `sigmas` is passed,
            `num_inference_steps` and `timesteps` must be `None`.

    Returns:
        `Tuple[torch.Tensor, int]`: A tuple where the first element is the timestep schedule from the scheduler and the
        second element is the number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    if timesteps is not None:
        # Custom timestep schedule: only valid if the scheduler's set_timesteps accepts it.
        if "timesteps" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    elif sigmas is not None:
        # Custom sigma schedule: same capability check for the `sigmas` keyword.
        if "sigmas" not in inspect.signature(scheduler.set_timesteps).parameters:
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        timesteps = scheduler.timesteps
        num_inference_steps = len(timesteps)
    else:
        # Default path: let the scheduler derive the schedule from the step count.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)
        timesteps = scheduler.timesteps
    return timesteps, num_inference_steps
|
| 136 |
+
|
| 137 |
+
|
| 138 |
+
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for Stable Video Diffusion pipeline.

    Args:
        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
            num_frames, height, width, num_channels)`.
    """

    # Decoded video frames; concrete type depends on the `output_type` passed to `__call__`.
    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class StableVideoDiffusionPipeline(DiffusionPipeline):
    r"""
    Pipeline to generate video from an input image using Stable Video Diffusion.

    This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
    implemented for all pipelines (downloading, saving, running on a particular device, etc.).

    Args:
        vae ([`AutoencoderKLTemporalDecoder`]):
            Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
        image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
            Frozen CLIP image-encoder
            ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
        unet ([`UNetSpatioTemporalConditionModel`]):
            A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
        scheduler ([`EulerDiscreteScheduler`]):
            A scheduler to be used in combination with `unet` to denoise the encoded image latents.
        feature_extractor ([`~transformers.CLIPImageProcessor`]):
            A `CLIPImageProcessor` to extract features from generated images.
    """

    model_cpu_offload_seq = "image_encoder->unet->vae"
    _callback_tensor_inputs = ["latents"]

    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        """Register the submodules and derive the VAE spatial scale factor."""
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        # 2**(num_blocks - 1): spatial downsampling between pixel space and VAE latent space.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)

    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ) -> torch.Tensor:
        """Encode `image` with the CLIP image encoder into conditioning embeddings.

        Returns a tensor of shape `(batch * num_videos_per_prompt, 1, embed_dim)`. When
        `do_classifier_free_guidance` is True, zero (unconditional) embeddings are
        concatenated in front along the batch dimension.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image with for CLIP input
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings

    def _encode_vae_image(
        self,
        image: torch.Tensor,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        """Encode `image` into VAE latents (distribution mode, no sampling).

        When classifier-free guidance is on, zero latents are prepended along the batch
        dimension as the unconditional branch.
        """
        image = image.to(device=device)
        image_latents = self.vae.encode(image).latent_dist.mode()

        # duplicate image_latents for each generation per prompt, using mps friendly method
        image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)

        if do_classifier_free_guidance:
            negative_image_latents = torch.zeros_like(image_latents)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_latents = torch.cat([negative_image_latents, image_latents])

        return image_latents

    def _get_add_time_ids(
        self,
        fps: int,
        motion_bucket_id: int,
        noise_aug_strength: float,
        dtype: torch.dtype,
        batch_size: int,
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ):
        """Build the SVD micro-conditioning tensor `[fps, motion_bucket_id, noise_aug_strength]`.

        The vector is tiled over the effective batch and doubled when classifier-free
        guidance is enabled. Raises `ValueError` if the UNet's added-time embedding
        dimension does not match the three conditioning values.
        """
        add_time_ids = [fps, motion_bucket_id, noise_aug_strength]

        passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
        expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features

        if expected_add_embed_dim != passed_add_embed_dim:
            raise ValueError(
                f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
            )

        add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
        add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)

        if do_classifier_free_guidance:
            add_time_ids = torch.cat([add_time_ids, add_time_ids])

        return add_time_ids

    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        """Decode video latents to pixel space in chunks of `decode_chunk_size` frames.

        Returns a float32 tensor of shape `(batch, channels, num_frames, height, width)`.
        """
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        latents = 1 / self.vae.config.scaling_factor * latents

        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames

    def check_inputs(self, image, height, width):
        """Validate the `image` type and that `height`/`width` are divisible by 8."""
        if (
            not isinstance(image, torch.Tensor)
            and not isinstance(image, PIL.Image.Image)
            and not isinstance(image, list)
        ):
            raise ValueError(
                "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
                f" {type(image)}"
            )

        if height % 8 != 0 or width % 8 != 0:
            raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")

    def prepare_latents(
        self,
        batch_size: int,
        num_frames: int,
        num_channels_latents: int,
        height: int,
        width: int,
        dtype: torch.dtype,
        device: Union[str, torch.device],
        generator: torch.Generator,
        latents: Optional[torch.Tensor] = None,
    ):
        """Create (or reuse) the initial noise latents, scaled by the scheduler's `init_noise_sigma`."""
        shape = (
            batch_size,
            num_frames,
            # Halved: the UNet's `in_channels` also covers the image latents concatenated
            # along the channel dim in `__call__`.
            num_channels_latents // 2,
            height // self.vae_scale_factor,
            width // self.vae_scale_factor,
        )
        if isinstance(generator, list) and len(generator) != batch_size:
            raise ValueError(
                f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
                f" size of {batch_size}. Make sure the batch size matches the length of the generators."
            )

        if latents is None:
            latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
        else:
            latents = latents.to(device)

        # scale the initial noise by the standard deviation required by the scheduler
        latents = latents * self.scheduler.init_noise_sigma
        return latents

    @property
    def guidance_scale(self):
        # Scalar before the denoising loop; becomes a per-frame tensor once `__call__` sets it.
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        if isinstance(self.guidance_scale, (int, float)):
            return self.guidance_scale > 1
        return self.guidance_scale.max() > 1

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps of the most recent `__call__`.
        return self._num_timesteps

    @torch.no_grad()
    @replace_example_docstring(EXAMPLE_DOC_STRING)
    def __call__(
        self,
        image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        sigmas: Optional[List[float]] = None,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.Tensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
    ):
        r"""
        The call function to the pipeline for generation.

        Args:
            image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
                Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
                1]`.
            height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The height in pixels of the generated image.
            width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
                The width in pixels of the generated image.
            num_frames (`int`, *optional*):
                The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
                `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
            num_inference_steps (`int`, *optional*, defaults to 25):
                The number of denoising steps. More denoising steps usually lead to a higher quality video at the
                expense of slower inference. This parameter is modulated by `strength`.
            sigmas (`List[float]`, *optional*):
                Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
                their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
                will be used.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                The minimum guidance scale. Used for the classifier free guidance with first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                The maximum guidance scale. Used for the classifier free guidance with last frame.
            fps (`int`, *optional*, defaults to 7):
                Frames per second. The rate at which the generated images shall be exported to a video after
                generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                Used for conditioning the amount of motion for the generation. The higher the number the more motion
                will be in the video.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                The amount of noise added to the init image, the higher it is the less the video will look like the
                init image. Increase it for more motion.
            decode_chunk_size (`int`, *optional*):
                The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
                expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
                For lower memory usage, reduce `decode_chunk_size`.
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                The number of videos to generate per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
                generation deterministic.
            latents (`torch.Tensor`, *optional*):
                Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
                generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
                tensor is generated by sampling using the supplied random `generator`.
            output_type (`str`, *optional*, defaults to `"pil"`):
                The output format of the generated image. Choose between `pil`, `np` or `pt`.
            callback_on_step_end (`Callable`, *optional*):
                A function that is called at the end of each denoising step during inference. The function is called
                with the following arguments:
                `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
                `callback_kwargs` will include a list of all tensors as specified by
                `callback_on_step_end_tensor_inputs`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
                will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
                `._callback_tensor_inputs` attribute of your pipeline class.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
                plain tuple.

        Examples:

        Returns:
            [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
                If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
                returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
                returned.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(image, height, width)

        # 2. Define call parameters
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
        # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
        # corresponds to doing no classifier free guidance.
        self._guidance_scale = max_guidance_scale

        # 3. Encode input image
        image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)

        # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode input image using VAE
        image = self.video_processor.preprocess(image, height=height, width=width).to(device)
        noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
        image = image + noise_aug_strength * noise

        # The fp16 VAE may need fp32 for a numerically stable encode (per its config flag).
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        image_latents = self._encode_vae_image(
            image,
            device=device,
            num_videos_per_prompt=num_videos_per_prompt,
            do_classifier_free_guidance=self.do_classifier_free_guidance,
        )
        image_latents = image_latents.to(image_embeddings.dtype)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
        image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)

        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            self.do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps
        timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)

        # 7. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 8. Prepare guidance scale (linearly ramped per frame from min to max)
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 9. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate image_latents over channels dimension
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                # predict the noise residual
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # perform guidance
                if self.do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

                if XLA_AVAILABLE:
                    xm.mark_step()

        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
|
| 634 |
+
|
| 635 |
+
|
| 636 |
+
# resizing utils
|
| 637 |
+
# TODO: clean up later
|
| 638 |
+
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Resize `input` to `size` after a Gaussian pre-blur that suppresses aliasing."""
    src_h, src_w = input.shape[-2:]
    scale_h = src_h / size[0]
    scale_w = src_w / size[1]

    # Per-axis sigma, following skimage's anti-aliasing heuristic, clamped to stay positive.
    # https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigma_h = max((scale_h - 1.0) / 2.0, 0.001)
    sigma_w = max((scale_w - 1.0) / 2.0, 0.001)

    # Kernel size of ~2 sigma per side (Pillow uses 1 sigma but does 2 passes), at least 3.
    # https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206
    k_h = int(max(2.0 * 2 * sigma_h, 3))
    k_w = int(max(2.0 * 2 * sigma_w, 3))

    # Force both kernel dimensions to be odd.
    if k_h % 2 == 0:
        k_h += 1
    if k_w % 2 == 0:
        k_w += 1

    blurred = _gaussian_blur2d(input, (k_h, k_w), (sigma_h, sigma_w))
    return torch.nn.functional.interpolate(blurred, size=size, mode=interpolation, align_corners=align_corners)
|
| 665 |
+
|
| 666 |
+
|
| 667 |
+
def _compute_padding(kernel_size):
|
| 668 |
+
"""Compute padding tuple."""
|
| 669 |
+
# 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
|
| 670 |
+
# https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
|
| 671 |
+
if len(kernel_size) < 2:
|
| 672 |
+
raise AssertionError(kernel_size)
|
| 673 |
+
computed = [k - 1 for k in kernel_size]
|
| 674 |
+
|
| 675 |
+
# for even kernels we need to do asymmetric padding :(
|
| 676 |
+
out_padding = 2 * len(kernel_size) * [0]
|
| 677 |
+
|
| 678 |
+
for i in range(len(kernel_size)):
|
| 679 |
+
computed_tmp = computed[-(i + 1)]
|
| 680 |
+
|
| 681 |
+
pad_front = computed_tmp // 2
|
| 682 |
+
pad_rear = computed_tmp - pad_front
|
| 683 |
+
|
| 684 |
+
out_padding[2 * i + 0] = pad_front
|
| 685 |
+
out_padding[2 * i + 1] = pad_rear
|
| 686 |
+
|
| 687 |
+
return out_padding
|
| 688 |
+
|
| 689 |
+
|
| 690 |
+
def _filter2d(input, kernel):
    """Convolve a (B, C, H, W) tensor with a per-batch 2-D kernel, reflect-padded to keep H and W."""
    batch, channels, rows, cols = input.shape

    # Broadcast the kernel over the channel axis and match the input's device/dtype.
    expanded = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
    expanded = expanded.expand(-1, channels, -1, -1)
    k_h, k_w = expanded.shape[-2:]

    padded = torch.nn.functional.pad(input, _compute_padding([k_h, k_w]), mode="reflect")

    # Fold batch and channel into the group dimension so each (batch, channel)
    # slice is filtered by its own kernel in a single grouped convolution.
    flat_kernel = expanded.reshape(-1, 1, k_h, k_w)
    grouped = padded.view(-1, flat_kernel.size(0), padded.size(-2), padded.size(-1))

    filtered = torch.nn.functional.conv2d(
        grouped, flat_kernel, groups=flat_kernel.size(0), padding=0, stride=1
    )

    return filtered.view(batch, channels, rows, cols)
|
| 711 |
+
|
| 712 |
+
|
| 713 |
+
def _gaussian(window_size: int, sigma):
|
| 714 |
+
if isinstance(sigma, float):
|
| 715 |
+
sigma = torch.tensor([[sigma]])
|
| 716 |
+
|
| 717 |
+
batch_size = sigma.shape[0]
|
| 718 |
+
|
| 719 |
+
x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
|
| 720 |
+
|
| 721 |
+
if window_size % 2 == 0:
|
| 722 |
+
x = x + 0.5
|
| 723 |
+
|
| 724 |
+
gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
|
| 725 |
+
|
| 726 |
+
return gauss / gauss.sum(-1, keepdim=True)
|
| 727 |
+
|
| 728 |
+
|
| 729 |
+
def _gaussian_blur2d(input, kernel_size, sigma):
    """Apply a separable 2-D Gaussian blur: one 1-D pass along x, then one along y."""
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky, kx = int(kernel_size[0]), int(kernel_size[1])
    bs = sigma.shape[0]

    # sigma is laid out (sigma_y, sigma_x) per batch entry.
    kern_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kern_y = _gaussian(ky, sigma[:, 0].view(bs, 1))

    # Horizontal pass with a 1xK kernel, then vertical with Kx1.
    blurred_x = _filter2d(input, kern_x[..., None, :])
    return _filter2d(blurred_x, kern_y[..., None])
|
ctrl_world/src/models/unet_spatio_temporal_condition.py
ADDED
|
@@ -0,0 +1,522 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from dataclasses import dataclass
|
| 2 |
+
from typing import Dict, Optional, Tuple, Union
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
from diffusers.configuration_utils import ConfigMixin, register_to_config
|
| 8 |
+
from diffusers.loaders import UNet2DConditionLoadersMixin
|
| 9 |
+
from diffusers.utils import BaseOutput, logging
|
| 10 |
+
from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
|
| 11 |
+
from diffusers.models.embeddings import TimestepEmbedding, Timesteps
|
| 12 |
+
from diffusers.models.modeling_utils import ModelMixin
|
| 13 |
+
from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 17 |
+
|
| 18 |
+
|
| 19 |
+
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_frames, channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Filled in by the model's forward pass; `None` is only the dataclass field default.
    sample: torch.Tensor = None
|
| 30 |
+
|
| 31 |
+
|
| 32 |
+
class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
    r"""
    A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
    returns a sample shaped output.

    This model inherits from [`ModelMixin`]. Check the superclass documentation for it's generic methods implemented
    for all models (such as downloading or saving).

    Parameters:
        sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
            Height and width of input/output sample.
        in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
        out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
        down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
            The tuple of downsample blocks to use.
        up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
            The tuple of upsample blocks to use.
        block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
            The tuple of output channels for each block.
        addition_time_embed_dim: (`int`, defaults to 256):
            Dimension to encode the additional time ids.
        projection_class_embeddings_input_dim (`int`, defaults to 768):
            The dimension of the projection of encoded `added_time_ids`.
        layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
        cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1024):
            The dimension of the cross attention features.
        transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
            The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
            [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
            [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
        num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
            The number of attention heads.
        num_frames (`int`, *optional*, defaults to 25):
            The number of video frames the model is configured for.
    """

    _supports_gradient_checkpointing = True

    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        up_block_types: Tuple[str] = (
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
        num_frames: int = 25,
    ):
        super().__init__()

        self.sample_size = sample_size

        # Check inputs: per-block configuration tuples must all be the same length.
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input: per-frame 2D convolution into the first feature width.
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        # time embedding: sinusoidal projection followed by an MLP.
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # additional conditioning ids (fps, motion bucket, noise aug) are embedded
        # the same way and added onto the timestep embedding in forward().
        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Normalize scalar configs to one entry per down block.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim

        # down: every block but the last downsamples.
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockSpatioTemporal(
            block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
        )

        # count how many layers upsample the images
        self.num_upsamplers = 0

        # up: mirror of the down path, consuming configs in reverse order.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=1e-5,
                resolution_idx=i,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out: normalize, activate, and project back to `out_channels`.
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
        self.conv_act = nn.SiLU()

        self.conv_out = nn.Conv2d(
            block_out_channels[0],
            out_channels,
            kernel_size=3,
            padding=1,
        )

    @property
    def attn_processors(self) -> Dict[str, AttentionProcessor]:
        r"""
        Returns:
            `dict` of attention processors: A dictionary containing all attention processors used in the model with
            indexed by its weight name.
        """
        # set recursively
        processors = {}

        def fn_recursive_add_processors(
            name: str,
            module: torch.nn.Module,
            processors: Dict[str, AttentionProcessor],
        ):
            # Any submodule exposing `get_processor` is treated as an attention layer.
            if hasattr(module, "get_processor"):
                processors[f"{name}.processor"] = module.get_processor()

            for sub_name, child in module.named_children():
                fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)

            return processors

        for name, module in self.named_children():
            fn_recursive_add_processors(name, module, processors)

        return processors

    def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
        r"""
        Sets the attention processor to use to compute attention.

        Parameters:
            processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
                The instantiated processor class or a dictionary of processor classes that will be set as the processor
                for **all** `Attention` layers.

                If `processor` is a dict, the key needs to define the path to the corresponding cross attention
                processor. This is strongly recommended when setting trainable attention processors.

        """
        count = len(self.attn_processors.keys())

        if isinstance(processor, dict) and len(processor) != count:
            raise ValueError(
                f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
                f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
            )

        def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
            if hasattr(module, "set_processor"):
                if not isinstance(processor, dict):
                    module.set_processor(processor)
                else:
                    # dict form: consume the entry keyed by this module's dotted path.
                    module.set_processor(processor.pop(f"{name}.processor"))

            for sub_name, child in module.named_children():
                fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)

        for name, module in self.named_children():
            fn_recursive_attn_processor(name, module, processor)

    def set_default_attn_processor(self):
        """
        Disables custom attention processors and sets the default attention implementation.
        """
        if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
            processor = AttnProcessor()
        else:
            raise ValueError(
                f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
            )

        self.set_attn_processor(processor)

    # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
    def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
        """
        Sets the attention processor to use [feed forward
        chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).

        Parameters:
            chunk_size (`int`, *optional*):
                The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
                over each tensor of dim=`dim`.
            dim (`int`, *optional*, defaults to `0`):
                The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
                or dim=1 (sequence length).
        """
        if dim not in [0, 1]:
            raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")

        # By default chunk size is 1
        chunk_size = chunk_size or 1

        def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
            if hasattr(module, "set_chunk_feed_forward"):
                module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)

            for child in module.children():
                fn_recursive_feed_forward(child, chunk_size, dim)

        for module in self.children():
            fn_recursive_feed_forward(module, chunk_size, dim)

    def forward(
        self,
        sample: torch.Tensor,
        timestep: Union[torch.Tensor, float, int],
        encoder_hidden_states: torch.Tensor,
        added_time_ids: torch.Tensor,
        return_dict: bool = True,
        frame_level_cond=False,
    ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
        r"""
        The [`UNetSpatioTemporalConditionModel`] forward method.

        Args:
            sample (`torch.Tensor`):
                The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
            timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
            encoder_hidden_states (`torch.Tensor`):
                The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
            added_time_ids: (`torch.Tensor`):
                The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
                embeddings and added to the time embeddings.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
                of a plain tuple.
            frame_level_cond (`bool`, *optional*, defaults to `False`):
                When `True`, `encoder_hidden_states` is reshaped to `(batch * num_frames, seq, dim)` instead of being
                repeated per frame — i.e. the caller supplies a distinct conditioning row for every frame.
                NOTE(review): assumes the caller packs per-frame conditioning along the batch/sequence axes so that
                the reshape is valid — confirm against the caller.
        Returns:
            [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
                If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
                returned, otherwise a `tuple` is returned where the first element is the sample tensor.
        """
        # By default samples have to be AT least a multiple of the overall upsampling factor.
        # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
        # However, the upsampling interpolation output size can be forced to fit any upsampling size
        # on the fly if necessary.
        default_overall_up_factor = 2**self.num_upsamplers

        # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
        forward_upsample_size = False
        upsample_size = None

        if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
            logger.info("Forward upsample size to force interpolation output size.")
            forward_upsample_size = True

        # 1. time
        timesteps = timestep
        if not torch.is_tensor(timesteps):
            # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
            # This would be a good case for the `match` statement (Python 3.10+)
            is_mps = sample.device.type == "mps"
            is_npu = sample.device.type == "npu"
            if isinstance(timestep, float):
                # mps/npu lack float64/int64 support, hence the narrower dtypes there.
                dtype = torch.float32 if (is_mps or is_npu) else torch.float64
            else:
                dtype = torch.int32 if (is_mps or is_npu) else torch.int64
            timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
        elif len(timesteps.shape) == 0:
            timesteps = timesteps[None].to(sample.device)

        # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
        batch_size, num_frames = sample.shape[:2]
        timesteps = timesteps.expand(batch_size)

        t_emb = self.time_proj(timesteps)

        # `Timesteps` does not contain any weights and will always return f32 tensors
        # but time_embedding might actually be running in fp16. so we need to cast here.
        # there might be better ways to encapsulate this.
        t_emb = t_emb.to(dtype=sample.dtype)

        emb = self.time_embedding(t_emb)

        time_embeds = self.add_time_proj(added_time_ids.flatten())
        time_embeds = time_embeds.reshape((batch_size, -1))
        time_embeds = time_embeds.to(emb.dtype)
        aug_emb = self.add_embedding(time_embeds)
        emb = emb + aug_emb

        # Flatten the batch and frames dimensions
        # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
        sample = sample.flatten(0, 1)
        # Repeat the embeddings num_video_frames times
        # emb: [batch, channels] -> [batch * frames, channels]
        emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)

        # Frame-level pose conditioning (departure from stock diffusers): either
        # repeat one conditioning row per frame, or treat the incoming tensor as
        # already holding per-frame rows and just fold frames into the batch axis.
        if not frame_level_cond:
            # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
            encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
        else:
            encoder_hidden_states = encoder_hidden_states.reshape(batch_size * num_frames, -1, encoder_hidden_states.shape[-1])

        # 2. pre-process
        sample = self.conv_in(sample)

        # All-zero indicator: every position is treated as a video frame, not a still image.
        image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)

        # 3. down
        down_block_res_samples = (sample,)
        for downsample_block in self.down_blocks:
            if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    encoder_hidden_states=encoder_hidden_states,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample, res_samples = downsample_block(
                    hidden_states=sample,
                    temb=emb,
                    image_only_indicator=image_only_indicator,
                )

            down_block_res_samples += res_samples

        # 4. mid
        sample = self.mid_block(
            hidden_states=sample,
            temb=emb,
            encoder_hidden_states=encoder_hidden_states,
            image_only_indicator=image_only_indicator,
        )

        # 5. up
        for i, upsample_block in enumerate(self.up_blocks):
            is_final_block = i == len(self.up_blocks) - 1

            # Pop this block's skip connections off the stacked residuals.
            res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
            down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]

            # if we have not reached the final block and need to forward the
            # upsample size, we do it here
            if not is_final_block and forward_upsample_size:
                upsample_size = down_block_res_samples[-1].shape[2:]

            if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    encoder_hidden_states=encoder_hidden_states,
                    upsample_size=upsample_size,
                    image_only_indicator=image_only_indicator,
                )
            else:
                sample = upsample_block(
                    hidden_states=sample,
                    temb=emb,
                    res_hidden_states_tuple=res_samples,
                    upsample_size=upsample_size,
                    image_only_indicator=image_only_indicator,
                )

        # 6. post-process
        sample = self.conv_norm_out(sample)
        sample = self.conv_act(sample)
        sample = self.conv_out(sample)

        # 7. Reshape back to original shape
        sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])

        if not return_dict:
            return (sample,)

        return UNetSpatioTemporalConditionOutput(sample=sample)
|
ctrl_world/src/world_model.py
ADDED
|
@@ -0,0 +1,185 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
|
| 2 |
+
from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
|
| 3 |
+
from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
|
| 7 |
+
import json
|
| 8 |
+
import einops
|
| 9 |
+
import numpy as np
|
| 10 |
+
from huggingface_hub import snapshot_download
|
| 11 |
+
from transformers import AutoTokenizer, CLIPTextModelWithProjection
|
| 12 |
+
|
| 13 |
+
class Action_encoder2(nn.Module):
    """Project a chunk of robot actions (and optionally a CLIP text embedding)
    into the UNet's 1024-d conditioning space.

    Parameters
    ----------
    action_dim : int
        Dimensionality of a single action vector (e.g. 7 for DROID).
    action_num : int
        Number of actions per chunk (history + future frames).
    hidden_size : int
        Conditioning width expected by the UNet (stored, not otherwise used;
        the MLP output is hard-wired to 1024 to match the doubled 512-d CLIP
        text embedding added below).
    text_cond : bool
        If True, a text embedding may be added on top of the action tokens.
    """

    def __init__(self, action_dim, action_num, hidden_size, text_cond=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond

        # NOTE(review): input_dim assumes frame-level conditioning (one action
        # vector per token). With frame_level_cond=False the forward pass
        # flattens to action_num * action_dim features, which this MLP would
        # reject — confirm that code path is unused before relying on it.
        input_dim = int(action_dim)
        self.action_encode = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024)
        )
        # Kaiming initialization on every linear layer.
        # FIX: the original initialized only layers 0 and 2, leaving the final
        # projection (index 4) at PyTorch's default init.
        for layer_idx in (0, 2, 4):
            nn.init.kaiming_normal_(
                self.action_encode[layer_idx].weight, mode='fan_in', nonlinearity='relu'
            )

    def forward(self, action, texts=None, text_tokinizer=None, text_encoder=None, frame_level_cond=True,):
        """Encode actions, optionally adding a text embedding.

        action : (B, action_num, action_dim) tensor.
        texts : optional list[str]; requires `text_tokinizer` and
            `text_encoder` (note: parameter name keeps the original
            "tokinizer" spelling for caller compatibility).
        Returns (B, T, 1024) if frame_level_cond else (B, 1, 1024).
        """
        B, T, D = action.shape
        if not frame_level_cond:
            # Collapse the whole chunk into a single conditioning token.
            action = einops.rearrange(action, 'b t d -> b 1 (t d)')
        action = self.action_encode(action)

        if texts is not None and self.text_cond:
            # FIX(comment): the original comment claimed text was added "with
            # 50% probability"; no stochastic dropout happens here — the text
            # embedding is always added when provided and text_cond is True.
            with torch.no_grad():
                inputs = text_tokinizer(texts, padding='max_length', return_tensors="pt", truncation=True).to(text_encoder.device)
                outputs = text_encoder(**inputs)
                hidden_text = outputs.text_embeds  # (B, 512)
            # Tile the 512-d CLIP embedding to 1024-d and broadcast over T.
            hidden_text = einops.repeat(hidden_text, 'b c -> b 1 (n c)', n=2)  # (B, 1, 1024)

            action = action + hidden_text  # (B, T, 1024)
        return action  # (B, 1, 1024) or (B, T, 1024) if frame_level_cond
|
| 50 |
+
|
| 51 |
+
|
| 52 |
+
class CrtlWorld(nn.Module):
    """Ctrl-World world model: a Stable-Video-Diffusion backbone whose UNet is
    conditioned on robot actions (and optionally text) to roll out future
    video latents.

    NOTE(review): the class name keeps the original "Crtl" spelling because
    external configs reference it by name (config.json "class_name").

    Parameters
    ----------
    config : dict
        Expects keys: svd_model_path, clip_model_path, action_dim,
        num_history, num_frames, text_cond, data_stat_path, width, height,
        num_inference_steps, decode_chunk_size, guidance_scale, fps,
        motion_bucket_id.
    """

    def __init__(self, config: dict):
        super().__init__()

        self.config = config
        # Download + load the pretrained Stable Video Diffusion pipeline.
        model_local_path = snapshot_download(
            repo_id=config["svd_model_path"],  # e.g. "stabilityai/stable-video-diffusion-img2vid"
            repo_type="model"
        )
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_local_path,
            torch_dtype="auto"
        )

        # Swap in the custom UNet, seeding it from the pretrained weights.
        # strict=False: the custom UNet adds action-conditioning parameters
        # that do not exist in the SVD checkpoint.
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Freeze VAE and image encoder; only the UNet trains, with gradient
        # checkpointing to bound activation memory.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        # SVD is img2video-only, so load a separate CLIP text encoder for
        # language conditioning.
        model_local_path = snapshot_download(
            repo_id=config["clip_model_path"],  # e.g. "openai/clip-vit-base-patch32"
            repo_type="model"
        )
        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            model_local_path,
            torch_dtype="auto"
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)
        self.text_encoder.requires_grad_(False)

        # Action projector mapping raw actions into the UNet conditioning space.
        self.action_encoder = Action_encoder2(
            action_dim=config["action_dim"],
            action_num=int(config["num_history"] + config["num_frames"]),
            hidden_size=1024,
            text_cond=config["text_cond"],
        )

        # Per-dimension 1st/99th-percentile state statistics for normalization.
        # FIX: the original wrapped the path in f"{config["data_stat_path"]}",
        # which is a SyntaxError on Python < 3.12 (same quotes nested inside
        # an f-string) and redundant in any case.
        with open(config["data_stat_path"], 'r') as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

    def normalize_bound(
        self,
        data: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map `data` into [-1, 1] via the stored p01/p99 bounds, then clip."""
        ndata = 2 * (data - self.state_p01) / (self.state_p99 - self.state_p01 + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def decode(self, latents: torch.Tensor):
        """Decode (B, T, C, H, W) latents to pixel videos in [0, 1] on CPU.

        Decoding runs in chunks of `decode_chunk_size` frames to bound VRAM.
        """
        bsz, frame_num = latents.shape[:2]
        x = latents.flatten(0, 1)

        decoded = []
        chunk_size = self.config["decode_chunk_size"]
        for i in range(0, x.shape[0], chunk_size):
            # Undo the VAE scaling applied at encode time.
            chunk = x[i:i + chunk_size] / self.pipeline.vae.config.scaling_factor
            decode_kwargs = {"num_frames": chunk.shape[0]}
            out = self.pipeline.vae.decode(chunk, **decode_kwargs).sample
            decoded.append(out)

        videos = torch.cat(decoded, dim=0)
        videos = videos.reshape(bsz, frame_num, *videos.shape[1:])
        videos = ((videos / 2.0 + 0.5).clamp(0, 1))  # [-1,1] -> [0,1]
        videos = videos.detach().float().cpu()
        # FIX: the original computed `videos` but never returned it, so
        # decode() always yielded None.
        return videos

    def encode(self, img: torch.Tensor):
        """Encode one image (C, H, W in [0, 1]) into a scaled VAE latent."""
        x = img.unsqueeze(0)
        x = x * 2 - 1  # [0,1] -> [-1,1]

        vae = self.pipeline.vae
        with torch.no_grad():
            latent = vae.encode(x).latent_dist.sample()
            latent = latent * vae.config.scaling_factor

        return latent.detach()

    def action_text_encode(self, action: torch.Tensor, text):
        """Encode an action chunk (T, action_dim) plus an optional text prompt
        into conditioning tokens via the action encoder (no grad)."""
        action_tensor = action.unsqueeze(0)

        with torch.no_grad():
            if text is not None and self.config["text_cond"]:
                text_token = self.action_encoder(action_tensor, [text], self.tokenizer, self.text_encoder)
            else:
                text_token = self.action_encoder(action_tensor)

        return text_token.detach()

    def get_latent_views(self, frames, current_latent, text_token):
        """Run the Ctrl-World diffusion pipeline and return predicted latents.

        frames : list of history latents, concatenated into a
            (1, num_history, 4, stacked_H, W) conditioning tensor.
        current_latent : latent of the current observation.
        text_token : action/text conditioning from `action_text_encode`.
        """
        his_cond = torch.cat(frames, dim=0).unsqueeze(0)  # (1, num_history, 4, stacked_H, W)

        # Invoke the Ctrl-World pipeline's __call__ against the SVD pipeline
        # instance so its components (unet, vae, scheduler) are reused.
        with torch.no_grad():
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                self.pipeline,
                image=current_latent,
                text=text_token,
                width=self.config["width"],
                height=int(self.config["height"] * 3),  # 3 camera views stacked vertically
                num_frames=self.config["num_frames"],
                history=his_cond,
                num_inference_steps=self.config["num_inference_steps"],
                decode_chunk_size=self.config["decode_chunk_size"],
                max_guidance_scale=self.config["guidance_scale"],
                fps=self.config["fps"],
                motion_bucket_id=self.config["motion_bucket_id"],
                mask=None,
                output_type="latent",
                return_dict=False,
                frame_level_cond=True,
            )

        return latents
|