ctrl_world/droid/checkpoint-10000.pt ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:ed17de48180d4e6f89fd33c53e9fb7a0196189c1a67d44c2c486a279a80ea8a8
3
+ size 9281040326
ctrl_world/droid/config.json ADDED
@@ -0,0 +1,52 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "name": "droid_ctrl_world",
3
+ "env": "DROID",
4
+ "model_type": "ctrl_world",
5
+ "metadata": {
6
+ "num_history": 6,
7
+ "num_frames": 5,
8
+ "action_dim": 7
9
+ },
10
+ "util_folders":{
11
+ "models": "../src/models"
12
+ },
13
+ "models": [
14
+ {
15
+ "name": "world_model",
16
+ "framework": null,
17
+ "format": "state_dict",
18
+ "source": {
19
+ "weights_path": "checkpoint-10000.pt",
20
+ "class_path": "../src/world_model.py",
21
+ "class_name": "CrtlWorld",
22
+ "class_args": [
23
+ {
24
+ "svd_model_path": "stabilityai/stable-video-diffusion-img2vid",
25
+ "clip_model_path": "openai/clip-vit-base-patch32",
26
+ "num_history": 6,
27
+ "num_frames": 5,
28
+ "action_dim": 7,
29
+ "text_cond": true,
30
+ "motion_bucket_id": 127,
31
+ "fps": 7,
32
+ "guidance_scale": 1.0,
33
+ "num_inference_steps": 50,
34
+ "decode_chunk_size": 7,
35
+ "width": 320,
36
+ "height": 192
37
+ }]
38
+ },
39
+ "methods":
40
+ [
41
+ {
42
+ "name": "blocks_left_in_kv_cache",
43
+ "method_name": "blocks_left_in_kv_cache"
44
+ },
45
+ {
46
+ "name": "reset_kv_cache",
47
+ "method_name": "reset_kv_cache"
48
+ }
49
+ ]
50
+ }
51
+ ]
52
+ }
ctrl_world/src/models/pipeline_ctrl_world.py ADDED
@@ -0,0 +1,823 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from typing import Callable, Dict, List, Optional, Union
2
+ import torch
3
+ from einops import rearrange, repeat
4
+ import PIL
5
+ import einops
6
+
7
+ # from diffusers import TextToVideoSDPipeline, StableVideoDiffusionPipeline
8
+ from diffusers import TextToVideoSDPipeline
9
+ from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
10
+
11
+
12
+ from diffusers.pipelines.text_to_video_synthesis.pipeline_text_to_video_synth import TextToVideoSDPipelineOutput
13
+ from diffusers.pipelines.stable_video_diffusion.pipeline_stable_video_diffusion import StableVideoDiffusionPipelineOutput
14
+ from diffusers.utils.torch_utils import randn_tensor
15
+
16
def svd_tensor2vid(video: torch.Tensor, processor, output_type="np"):
    """Convert a batched video tensor into a list of post-processed clips.

    Based on:
    https://github.com/modelscope/modelscope/blob/1509fdb973e5871f37148a4b5e5964cafd43e64d/modelscope/pipelines/multi_modal/text_to_video_synthesis_pipeline.py#L78

    Args:
        video: tensor of shape `(batch, channels, num_frames, height, width)`.
        processor: object exposing `postprocess(frames, output_type)`
            (e.g. a diffusers video/image processor).
        output_type: forwarded to `processor.postprocess`.

    Returns:
        A list with one post-processed clip per batch element; each clip is
        permuted to frame-major `(num_frames, channels, height, width)`
        before post-processing.
    """
    return [
        processor.postprocess(clip.permute(1, 0, 2, 3), output_type)
        for clip in video
    ]
29
+
30
class LatentToVideoPipeline(TextToVideoSDPipeline):
    """Text-to-video pipeline that denoises caller-supplied latents.

    Unlike the stock ``TextToVideoSDPipeline``, the caller must supply the
    initial noisy ``latents`` (whose device also selects the execution
    device) and may pass extra conditioning (``condition_latent``, ``mask``,
    ``motion``) that is forwarded to the UNet.
    """

    @torch.no_grad()
    def __call__(
        self,
        prompt=None,
        height=None,
        width=None,
        num_frames: int = 16,
        num_inference_steps: int = 50,
        guidance_scale=9.0,
        negative_prompt=None,
        eta: float = 0.0,
        generator=None,
        latents=None,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        output_type="np",
        return_dict: bool = True,
        callback=None,
        callback_steps: int = 1,
        cross_attention_kwargs=None,
        condition_latent=None,
        mask=None,
        timesteps=None,
        motion=None,
    ):
        r"""Generate a video from a text prompt, starting from ``latents``.

        Args:
            prompt (`str` or `List[str]`, *optional*):
                The prompt or prompts to guide generation; omit when passing
                `prompt_embeds` instead.
            height (`int`, *optional*):
                Output height in pixels; defaults to the UNet sample size
                times the VAE scale factor.
            width (`int`, *optional*):
                Output width in pixels; same default rule as `height`.
            num_frames (`int`, *optional*, defaults to 16):
                Number of video frames to generate.
            num_inference_steps (`int`, *optional*, defaults to 50):
                Number of denoising steps (overridden by `timesteps`).
            guidance_scale (`float`, *optional*, defaults to 9.0):
                Classifier-free guidance weight `w`; CFG is enabled when > 1.
            negative_prompt (`str` or `List[str]`, *optional*):
                Negative conditioning; ignored without CFG.
            eta (`float`, *optional*, defaults to 0.0):
                DDIM eta; ignored by other schedulers.
            generator (`torch.Generator` or list, *optional*):
                RNG(s) for deterministic sampling.
            latents (`torch.FloatTensor`):
                REQUIRED pre-sampled noisy latents of shape
                `(batch, channels, frames, height, width)`; their device is
                used as the execution device.
            prompt_embeds / negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Pre-computed text embeddings.
            output_type (`str`, *optional*, defaults to `"np"`):
                `"pt"` returns the raw decoded tensor, anything else is
                post-processed to a numpy video.
            return_dict (`bool`, *optional*, defaults to `True`):
                Whether to return a `TextToVideoSDPipelineOutput`.
            callback / callback_steps:
                Optional per-step hook `callback(step, timestep, latents)`.
            cross_attention_kwargs (`dict`, *optional*):
                Forwarded to the attention processors.
            condition_latent (*optional*):
                Conditioning latent forwarded to the UNet; under CFG the same
                tensor is used for both the conditional and unconditional
                halves (CFG contrasts text only, not this latent).
            mask (*optional*):
                Mask forwarded to the UNet.
            timesteps (*optional*):
                Explicit timestep sequence; overrides the scheduler's.
            motion (*optional*):
                Motion conditioning forwarded to the UNet.

        Returns:
            `TextToVideoSDPipelineOutput` if `return_dict` is True, otherwise
            the tuple `(video, latents)`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_images_per_prompt = 1

        # 1. Check inputs. Raise error if not correct
        self.check_inputs(
            prompt, height, width, callback_steps, negative_prompt, prompt_embeds, negative_prompt_embeds
        )

        # 2. Define call parameters
        if prompt is not None and isinstance(prompt, str):
            batch_size = 1
        elif prompt is not None and isinstance(prompt, list):
            batch_size = len(prompt)
        else:
            batch_size = prompt_embeds.shape[0]

        # Run on whatever device the caller-provided latents live on
        # (intentionally not self._execution_device).
        device = latents.device
        # `guidance_scale` is `w` of eq. (2) in the Imagen paper
        # (https://arxiv.org/pdf/2205.11487.pdf); 1 means no CFG.
        do_classifier_free_guidance = guidance_scale > 1.0

        # 3. Encode input prompt
        text_encoder_lora_scale = (
            cross_attention_kwargs.get("scale", None) if cross_attention_kwargs is not None else None
        )
        prompt_embeds = self._encode_prompt(
            prompt,
            device,
            num_images_per_prompt,
            do_classifier_free_guidance,
            negative_prompt,
            prompt_embeds=prompt_embeds,
            negative_prompt_embeds=negative_prompt_embeds,
            lora_scale=text_encoder_lora_scale,
        )

        # 4. Prepare timesteps; an explicit `timesteps` argument wins.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        if timesteps is None:
            timesteps = self.scheduler.timesteps
        else:
            num_inference_steps = len(timesteps)
        # 5. Prepare latent variables: nothing to do, latents are supplied.

        # 6. Prepare extra step kwargs.
        extra_step_kwargs = self.prepare_extra_step_kwargs(generator, eta)

        # 7. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        # Same tensor for the "unconditional" half: CFG contrasts text vs.
        # no-text only, the latent condition is always applied.
        uncondition_latent = condition_latent
        condition_latent = torch.cat([uncondition_latent, condition_latent]) if do_classifier_free_guidance else condition_latent
        # FIX: convert `motion` once, outside the loop. The original wrapped it
        # with torch.tensor() on every step, re-copying the already-converted
        # tensor on each iteration after the first.
        if motion is not None:
            motion = torch.tensor(motion, device=device)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=prompt_embeds,
                    cross_attention_kwargs=cross_attention_kwargs,
                    condition_latent=condition_latent,
                    mask=mask,
                    motion=motion,
                ).sample
                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_text = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + guidance_scale * (noise_pred_text - noise_pred_uncond)

                # reshape latents: the scheduler steps per-image, so fold the
                # frame axis into the batch axis.
                bsz, channel, frames, width, height = latents.shape
                latents = latents.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)
                noise_pred = noise_pred.permute(0, 2, 1, 3, 4).reshape(bsz * frames, channel, width, height)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents, **extra_step_kwargs).prev_sample

                # reshape latents back to (batch, channels, frames, h, w)
                latents = latents[None, :].reshape(bsz, frames, channel, width, height).permute(0, 2, 1, 3, 4)

                # call the callback, if provided
                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()
                    if callback is not None and i % callback_steps == 0:
                        callback(i, t, latents)

        video_tensor = self.decode_latents(latents)

        if output_type == "pt":
            video = video_tensor
        else:
            # FIX: the original called `tensor2vid`, which is neither defined
            # nor imported in this module (NameError at runtime). Use the
            # in-file helper with the pipeline's processor instead.
            # NOTE(review): processor attribute name varies across diffusers
            # versions (video_processor vs image_processor) — confirm.
            processor = getattr(self, "video_processor", None) or getattr(self, "image_processor", None)
            video = svd_tensor2vid(video_tensor, processor, output_type=output_type)

        # Offload last model to CPU
        if hasattr(self, "final_offload_hook") and self.final_offload_hook is not None:
            self.final_offload_hook.offload()

        if not return_dict:
            return (video, latents)

        return TextToVideoSDPipelineOutput(frames=video)
233
+
234
+ def _append_dims(x, target_dims):
235
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
236
+ dims_to_append = target_dims - x.ndim
237
+ if dims_to_append < 0:
238
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
239
+ return x[(...,) + (None,) * dims_to_append]
240
+
241
class CtrlWorldDiffusionPipeline(StableVideoDiffusionPipeline):
    """SVD-based world-model pipeline with text, history, and wrist-cam conditioning.

    Differences from the stock ``StableVideoDiffusionPipeline.__call__``:
    the CLIP image embedding is replaced by caller-supplied ``text``
    embeddings, ``image`` may already be a VAE latent, optional ``history``
    latents are prepended along the frame axis of the UNet input, and an
    optional ``cond_wrist`` latent is concatenated along the height axis on
    the first denoising step only.
    """

    @torch.no_grad()
    def __call__(
        self,
        image,
        text,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,  # annotation corrected: value is a float
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],  # NOTE(review): mutable default; only read here
        return_dict: bool = True,
        mask = None,  # NOTE(review): accepted but never used in this body
        cond_wrist=None,
        history=None,
        frame_level_cond=False,
        his_cond_zero=False,
    ):
        r"""Run the denoising loop and return decoded frames.

        Args:
            image: either pixels (the branch is chosen by
                `image.shape[-3] == 3`, which is then preprocessed and VAE
                encoded) or an already-scaled latent (divided by
                `vae.config.scaling_factor` and duplicated for CFG).
            text: conditioning embeddings used in place of the CLIP image
                embedding (`encoder_hidden_states` for the UNet); its first
                dimension defines the batch size. The CFG negative is an
                all-zeros embedding.
            height / width: output resolution; default derived from the UNet
                sample size when falsy.
            num_frames: frames to generate; defaults to
                `unet.config.num_frames`.
            num_inference_steps: number of scheduler steps.
            min_guidance_scale / max_guidance_scale: per-frame linear CFG
                ramp; CFG is enabled when `max_guidance_scale > 1`.
            fps, motion_bucket_id, noise_aug_strength: SVD micro-conditioning
                passed to `_get_add_time_ids`.
            decode_chunk_size: frames decoded per VAE call; defaults to
                `num_frames`.
            generator / latents: RNG(s) and optional pre-sampled latents for
                `prepare_latents`.
            output_type: `"latent"` returns raw latents; otherwise frames are
                decoded and post-processed by `svd_tensor2vid`.
            callback_on_step_end / callback_on_step_end_tensor_inputs:
                per-step hook; named locals are passed back and `latents` may
                be replaced by the hook.
            cond_wrist: optional extra latent, repeated via einops to match
                the frame/channel layout and concatenated on the height axis
                of the UNet input at step 0 only; cropped back off the
                prediction. Expected shape `(b, l, c, h, w)` —
                TODO confirm against the caller.
            history: optional history latents `(B, num_his, C, H, W)`
                prepended along the frame axis each step and sliced off the
                prediction.
            frame_level_cond: forwarded to the UNet (custom kwarg).
            his_cond_zero: if True, zero out the image-latent conditioning
                for the history frames.

        Returns:
            `StableVideoDiffusionPipelineOutput` if `return_dict` is True,
            otherwise the tuple `(frames, latents)`.
        """
        # 0. Default height and width to unet
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor

        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
        # device = self._execution_device
        # Run wherever the UNet lives (intentionally not _execution_device).
        device = self.unet.device
        # `guidance_scale = 1` corresponds to doing no classifier free
        # guidance (w of eq. (2), https://arxiv.org/pdf/2205.11487.pdf).
        do_classifier_free_guidance = max_guidance_scale > 1.0

        # Input checking and CLIP image encoding from the parent pipeline are
        # deliberately skipped: `text` already provides the embeddings.
        image_embeddings = text
        batch_size = image_embeddings.shape[0]
        if do_classifier_free_guidance:
            # Zero embedding as the unconditional half for CFG.
            negative_image_embeddings = torch.zeros_like(image_embeddings)
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
        # is why upstream reduces it here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        # fps = fps - 1  # we only use fps = 7 in train, so just set to 7

        # 4. Encode input image using VAE
        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if image.shape[-3] == 3: # (batch, 3, 256, 256)
            image = self.video_processor.preprocess(image, height=height, width=width)
            # NOTE(review): `noise` is computed but unused — the augmentation
            # line below is commented out.
            noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
            # image = image + noise_aug_strength * noise

            if needs_upcasting:
                self.vae.to(dtype=torch.float32)

            image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
            image_latents = image_latents.to(image_embeddings.dtype)

            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
        else: # (batch, 4, 32, 32)
            # Caller passed a latent that was scaled on the way in; undo it.
            image_latents = image/self.vae.config.scaling_factor
            if do_classifier_free_guidance:
                # The SAME latent is used for both CFG halves (not zeros):
                # CFG contrasts the text embedding only.
                # negative_image_latent = torch.zeros_like(image_latents)
                # image_latents = torch.cat([negative_image_latent, image_latents])
                image_latents = torch.cat([image_latents]*2)
            image_latents = image_latents.to(image_embeddings.dtype)

        # Repeat the image latents for each frame so we can concatenate them with the noise
        # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
        if history is not None:
            B, num_his, C, H, W = history.shape
            # Conditioning must also cover the prepended history frames.
            num_frames_all = num_frames + num_his
            image_latents = image_latents.unsqueeze(1).repeat(1, num_frames_all, 1, 1, 1)
            if his_cond_zero:
                image_latents[:,:num_his] = 0.0 # set history to 0
        else:
            image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        # mask = repeat(mask, '1 h w -> 2 f 1 h w', f=num_frames)
        # 5. Get Added Time IDs
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 4. Prepare timesteps
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        # 5. Prepare latent variables
        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 7. Prepare guidance scale: a per-frame linear ramp from min to max,
        # broadcast to the latent rank.
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 8. Denoising loop
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        if cond_wrist is not None:
            # H/W recorded here are used after the UNet call to crop the
            # wrist-cam rows back off the prediction.
            B,F, C, H, W = latents.shape
            cond_wrist = einops.repeat(cond_wrist, 'b l c h w -> b (f l) (n c) h w', n=3,f=num_frames) # (B, 8, 12 , 24, 40)
            cond_wrist = torch.cat([cond_wrist]*2) if do_classifier_free_guidance else cond_wrist

        if history is not None:
            history = torch.cat([history] * 2) if do_classifier_free_guidance else history

        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # expand the latents if we are doing classifier free guidance
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                if history is not None:
                    # Prepend history along the frame axis every step.
                    latent_model_input = torch.cat([history, latent_model_input], dim=1) # (bsz*2,frame+F,4,32,32)

                # Concatenate image_latents over channels dimention
                latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)

                if cond_wrist is not None and i==0:
                    # Wrist-cam conditioning is appended on the height axis
                    # for the FIRST step only.
                    latent_model_input = torch.cat([latent_model_input, cond_wrist], dim=3) # (B, 8, 12, 96, 40)

                # predict the noise residual
                latent_model_input = latent_model_input.to(self.unet.dtype)
                image_embeddings = image_embeddings.to(self.unet.dtype)
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                    frame_level_cond=frame_level_cond,
                )[0]

                if cond_wrist is not None:
                    noise_pred = noise_pred[:, :,:,:H, :W] # remove cond_wrist
                if history is not None:
                    # num_his was captured from history.shape above.
                    noise_pred = noise_pred[:, num_his:, :, :, :] # remove history

                # perform guidance
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # compute the previous noisy sample x_t -> x_t-1
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    # Hand the named locals to the hook; it may replace latents.
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            # latents = latents/self.vae.config.scaling_factor
            latents = latents.to(self.vae.dtype)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = svd_tensor2vid(frames, self.video_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames,latents

        return StableVideoDiffusionPipelineOutput(frames=frames)
559
+
560
class TextStableVideoDiffusionPipeline(StableVideoDiffusionPipeline):
    """Stable Video Diffusion pipeline with optional text conditioning.

    Extends ``StableVideoDiffusionPipeline`` so the UNet cross-attention
    context can come from CLIP text embeddings (``condition_type="text"``),
    the CLIP image embedding (``condition_type="image"``), or both
    concatenated along the token axis (any other value). Also supports a
    caller-supplied ``condition_latent`` and an optional per-frame ``mask``
    channel for UNets configured with 9 input channels.
    """

    @torch.no_grad()
    def __call__(
        self,
        image,
        prompt_embeds=None,
        negative_prompt_embeds=None,
        height: int = 576,
        width: int = 1024,
        num_frames: Optional[int] = None,
        num_inference_steps: int = 25,
        min_guidance_scale: float = 1.0,
        max_guidance_scale: float = 3.0,
        fps: int = 7,
        motion_bucket_id: int = 127,
        noise_aug_strength: float = 0.02,  # fixed: was mis-annotated `int`
        decode_chunk_size: Optional[int] = None,
        num_videos_per_prompt: Optional[int] = 1,
        generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
        latents: Optional[torch.FloatTensor] = None,
        output_type: Optional[str] = "pil",
        callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
        callback_on_step_end_tensor_inputs: List[str] = ["latents"],
        return_dict: bool = True,
        mask=None,
        condition_type="image",
        condition_latent=None,
    ):
        r"""
        Generate a video conditioned on an image and/or text embeddings.

        Args:
            image (`PIL.Image.Image`, `List[PIL.Image.Image]`, or `torch.FloatTensor`):
                Conditioning image(s) for generation.
            prompt_embeds (`torch.FloatTensor`, *optional*):
                CLIP text embeddings, used when `condition_type` is `"text"` or
                in the combined image+text mode.
            negative_prompt_embeds (`torch.FloatTensor`, *optional*):
                Unconditional text embeddings for classifier-free guidance.
            height (`int`, *optional*, defaults to 576):
                Output height in pixels (must be divisible by 8).
            width (`int`, *optional*, defaults to 1024):
                Output width in pixels (must be divisible by 8).
            num_frames (`int`, *optional*):
                Frames to generate; defaults to `self.unet.config.num_frames`.
            num_inference_steps (`int`, *optional*, defaults to 25):
                Number of denoising steps.
            min_guidance_scale (`float`, *optional*, defaults to 1.0):
                Guidance scale applied to the first frame.
            max_guidance_scale (`float`, *optional*, defaults to 3.0):
                Guidance scale applied to the last frame. Classifier-free
                guidance is enabled when `max_guidance_scale > 1`.
            fps (`int`, *optional*, defaults to 7):
                Export frame rate; the UNet is micro-conditioned on `fps - 1`.
            motion_bucket_id (`int`, *optional*, defaults to 127):
                Motion conditioning; higher means more motion.
            noise_aug_strength (`float`, *optional*, defaults to 0.02):
                Noise added to the init image before VAE encoding.
            decode_chunk_size (`int`, *optional*):
                Frames decoded per VAE call (memory/quality trade-off).
            num_videos_per_prompt (`int`, *optional*, defaults to 1):
                Number of videos generated per prompt.
            generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
                RNG for deterministic generation.
            latents (`torch.FloatTensor`, *optional*):
                Pre-generated noise latents; sampled if not provided.
            output_type (`str`, *optional*, defaults to `"pil"`):
                One of `"pil"`, `"np"`, `"pt"`, or `"latent"`.
            callback_on_step_end (`Callable`, *optional*):
                Called after each denoising step with
                `(self, step, timestep, callback_kwargs)`.
            callback_on_step_end_tensor_inputs (`List`, *optional*):
                Names of local tensors passed to the callback; only names bound
                in the loop frame (e.g. `"latents"`) are valid.
            return_dict (`bool`, *optional*, defaults to `True`):
                Return a [`StableVideoDiffusionPipelineOutput`] instead of the
                raw frames.
            mask (`torch.FloatTensor`, *optional*):
                Per-frame mask latents; required when the UNet has 9 input
                channels.
            condition_type (`str`, *optional*, defaults to `"image"`):
                `"image"`, `"text"`, or any other value for image+text
                concatenation.
            condition_latent (`torch.FloatTensor`, *optional*):
                Pre-computed conditioning latents of shape
                `[batch, num_frames, C, h, w]`; skips VAE-encoding the image.

        Returns:
            [`StableVideoDiffusionPipelineOutput`] or the raw frames when
            `return_dict=False`.
        """
        # 0. Default height, width, frame count and decode chunking.
        height = height or self.unet.config.sample_size * self.vae_scale_factor
        width = width or self.unet.config.sample_size * self.vae_scale_factor
        num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
        decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames

        # 1. Check inputs. Raise error if not correct.
        self.check_inputs(image, height, width)

        # 2. Define call parameters.
        if isinstance(image, PIL.Image.Image):
            batch_size = 1
        elif isinstance(image, list):
            batch_size = len(image)
        else:
            batch_size = image.shape[0]
        device = self._execution_device
        # `guidance_scale = 1` corresponds to doing no classifier free guidance
        # (guidance weight `w` of Imagen eq. (2), https://arxiv.org/pdf/2205.11487.pdf).
        do_classifier_free_guidance = max_guidance_scale > 1.0

        # 3. Build the UNet cross-attention context from image and/or text.
        if condition_type == "image":
            image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
        elif condition_type == "text":
            if do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            image_embeddings = prompt_embeds
        else:
            # Combined mode: image embedding concatenated with text embeddings
            # along the token dimension.
            image_embeddings = self._encode_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
            if do_classifier_free_guidance:
                prompt_embeds = torch.cat([negative_prompt_embeds, prompt_embeds])
            image_embeddings = torch.cat([image_embeddings, prompt_embeds], dim=1)

        # A 9-channel UNet expects an extra mask channel on its input.
        motion_mask = self.unet.config.in_channels == 9
        if do_classifier_free_guidance and mask is not None:
            # BUGFIX: guard on `mask is not None` — previously this crashed
            # with CFG enabled whenever no mask was supplied.
            mask = torch.cat([mask] * 2)

        # NOTE: Stable Diffusion Video was conditioned on fps - 1, which
        # is why it is reduced here.
        # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
        fps = fps - 1

        # 4. Encode the (noise-augmented) conditioning image with the VAE.
        image = self.video_processor.preprocess(image, height=height, width=width)
        noise = randn_tensor(image.shape, generator=generator, device=image.device, dtype=image.dtype)
        image = image + noise_aug_strength * noise

        needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
        if needs_upcasting:
            self.vae.to(dtype=torch.float32)

        if condition_latent is None:
            image_latents = self._encode_vae_image(image, device, num_videos_per_prompt, do_classifier_free_guidance)
            image_latents = image_latents.to(image_embeddings.dtype)
            # Repeat the image latents for each frame so we can concatenate
            # them with the noise:
            # [batch, C, h, w] -> [batch, num_frames, C, h, w]
            condition_latent = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
        elif do_classifier_free_guidance:
            # Caller-provided latents must be duplicated for the CFG batch.
            condition_latent = torch.cat([condition_latent] * 2)

        # cast back to fp16 if needed
        if needs_upcasting:
            self.vae.to(dtype=torch.float16)

        # 5. Micro-conditioning IDs (fps, motion bucket, noise augmentation).
        added_time_ids = self._get_add_time_ids(
            fps,
            motion_bucket_id,
            noise_aug_strength,
            image_embeddings.dtype,
            batch_size,
            num_videos_per_prompt,
            do_classifier_free_guidance,
        )
        added_time_ids = added_time_ids.to(device)

        # 6. Prepare timesteps and initial latent variables.
        self.scheduler.set_timesteps(num_inference_steps, device=device)
        timesteps = self.scheduler.timesteps

        num_channels_latents = self.unet.config.in_channels
        latents = self.prepare_latents(
            batch_size * num_videos_per_prompt,
            num_frames,
            num_channels_latents,
            height,
            width,
            image_embeddings.dtype,
            device,
            generator,
            latents,
        )

        # 7. Per-frame guidance scale, broadcast up to the latents' rank.
        guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
        guidance_scale = guidance_scale.to(device, latents.dtype)
        guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
        guidance_scale = _append_dims(guidance_scale, latents.ndim)

        self._guidance_scale = guidance_scale

        # 8. Denoising loop.
        num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
        self._num_timesteps = len(timesteps)
        with self.progress_bar(total=num_inference_steps) as progress_bar:
            for i, t in enumerate(timesteps):
                # Expand the latents if we are doing classifier free guidance.
                latent_model_input = torch.cat([latents] * 2) if do_classifier_free_guidance else latents
                latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)

                # Concatenate (mask and) conditioning latents over channels.
                if motion_mask:
                    latent_model_input = torch.cat([mask, latent_model_input, condition_latent], dim=2)
                else:
                    latent_model_input = torch.cat([latent_model_input, condition_latent], dim=2)

                # Predict the noise residual.
                noise_pred = self.unet(
                    latent_model_input,
                    t,
                    encoder_hidden_states=image_embeddings,
                    added_time_ids=added_time_ids,
                    return_dict=False,
                )[0]

                # Perform classifier-free guidance.
                if do_classifier_free_guidance:
                    noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
                    noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)

                # Compute the previous noisy sample x_t -> x_t-1.
                latents = self.scheduler.step(noise_pred, t, latents).prev_sample

                if callback_on_step_end is not None:
                    callback_kwargs = {}
                    for k in callback_on_step_end_tensor_inputs:
                        # NOTE(review): relies on `locals()`; only names bound
                        # in this frame (e.g. "latents") are valid entries.
                        callback_kwargs[k] = locals()[k]
                    callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)

                    latents = callback_outputs.pop("latents", latents)

                if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
                    progress_bar.update()

        # 9. Decode to the requested output format.
        if not output_type == "latent":
            # cast back to fp16 if needed
            if needs_upcasting:
                self.vae.to(dtype=torch.float16)
            frames = self.decode_latents(latents, num_frames, decode_chunk_size)
            frames = svd_tensor2vid(frames, self.video_processor, output_type=output_type)
        else:
            frames = latents

        self.maybe_free_model_hooks()

        if not return_dict:
            return frames

        return StableVideoDiffusionPipelineOutput(frames=frames)
ctrl_world/src/models/pipeline_stable_video_diffusion.py ADDED
@@ -0,0 +1,742 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # Copyright 2025 The HuggingFace Team. All rights reserved.
2
+ #
3
+ # Licensed under the Apache License, Version 2.0 (the "License");
4
+ # you may not use this file except in compliance with the License.
5
+ # You may obtain a copy of the License at
6
+ #
7
+ # http://www.apache.org/licenses/LICENSE-2.0
8
+ #
9
+ # Unless required by applicable law or agreed to in writing, software
10
+ # distributed under the License is distributed on an "AS IS" BASIS,
11
+ # WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12
+ # See the License for the specific language governing permissions and
13
+ # limitations under the License.
14
+
15
+ import inspect
16
+ from dataclasses import dataclass
17
+ from typing import Callable, Dict, List, Optional, Union
18
+
19
+ import numpy as np
20
+ import PIL.Image
21
+ import torch
22
+ from transformers import CLIPImageProcessor, CLIPVisionModelWithProjection
23
+
24
+ from diffusers.image_processor import PipelineImageInput
25
+
26
+ # import from our own models instead of diffusers
27
+ # from diffusers.models import AutoencoderKLTemporalDecoder, UNetSpatioTemporalConditionModel
28
+ from diffusers.models import AutoencoderKLTemporalDecoder
29
+ from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
30
+
31
+ from diffusers.schedulers import EulerDiscreteScheduler
32
+ from diffusers.utils import BaseOutput, is_torch_xla_available, logging, replace_example_docstring
33
+ from diffusers.utils.torch_utils import is_compiled_module, randn_tensor
34
+ from diffusers.video_processor import VideoProcessor
35
+ from diffusers.pipelines.pipeline_utils import DiffusionPipeline
36
+
37
+
38
# Flag XLA (TPU) support so downstream code can branch on `XLA_AVAILABLE`
# and use `xm` when running under torch_xla.
if is_torch_xla_available():
    import torch_xla.core.xla_model as xm

    XLA_AVAILABLE = True
else:
    XLA_AVAILABLE = False

# Module-level logger following the diffusers convention.
logger = logging.get_logger(__name__)  # pylint: disable=invalid-name
46
+
47
+
48
+ EXAMPLE_DOC_STRING = """
49
+ Examples:
50
+ ```py
51
+ >>> from diffusers import StableVideoDiffusionPipeline
52
+ >>> from diffusers.utils import load_image, export_to_video
53
+
54
+ >>> pipe = StableVideoDiffusionPipeline.from_pretrained(
55
+ ... "stabilityai/stable-video-diffusion-img2vid-xt", torch_dtype=torch.float16, variant="fp16"
56
+ ... )
57
+ >>> pipe.to("cuda")
58
+
59
+ >>> image = load_image(
60
+ ... "https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/diffusers/svd-docstring-example.jpeg"
61
+ ... )
62
+ >>> image = image.resize((1024, 576))
63
+
64
+ >>> frames = pipe(image, num_frames=25, decode_chunk_size=8).frames[0]
65
+ >>> export_to_video(frames, "generated.mp4", fps=7)
66
+ ```
67
+ """
68
+
69
+
70
+ def _append_dims(x, target_dims):
71
+ """Appends dimensions to the end of a tensor until it has target_dims dimensions."""
72
+ dims_to_append = target_dims - x.ndim
73
+ if dims_to_append < 0:
74
+ raise ValueError(f"input has {x.ndim} dims but target_dims is {target_dims}, which is less")
75
+ return x[(...,) + (None,) * dims_to_append]
76
+
77
+
78
# Copied from diffusers.pipelines.stable_diffusion.pipeline_stable_diffusion.retrieve_timesteps
def retrieve_timesteps(
    scheduler,
    num_inference_steps: Optional[int] = None,
    device: Optional[Union[str, torch.device]] = None,
    timesteps: Optional[List[int]] = None,
    sigmas: Optional[List[float]] = None,
    **kwargs,
):
    r"""
    Configure `scheduler` and return its timestep schedule.

    Exactly one of `num_inference_steps`, `timesteps`, or `sigmas` drives the
    schedule. Custom `timesteps`/`sigmas` are only legal when the scheduler's
    `set_timesteps` accepts the corresponding keyword. All extra `kwargs` are
    forwarded to `scheduler.set_timesteps`.

    Args:
        scheduler (`SchedulerMixin`):
            The scheduler to get timesteps from.
        num_inference_steps (`int`):
            Number of diffusion steps. Must be `None` when `timesteps` is used.
        device (`str` or `torch.device`, *optional*):
            Device for the timesteps; `None` leaves them where they are.
        timesteps (`List[int]`, *optional*):
            Custom timestep schedule; mutually exclusive with
            `num_inference_steps` and `sigmas`.
        sigmas (`List[float]`, *optional*):
            Custom sigma schedule; mutually exclusive with
            `num_inference_steps` and `timesteps`.

    Returns:
        `Tuple[torch.Tensor, int]`: The scheduler's timestep schedule and the
        number of inference steps.
    """
    if timesteps is not None and sigmas is not None:
        raise ValueError("Only one of `timesteps` or `sigmas` can be passed. Please choose one to set custom values")

    def _supports(keyword):
        # Whether this scheduler's `set_timesteps` accepts the given keyword.
        return keyword in inspect.signature(scheduler.set_timesteps).parameters

    if timesteps is not None:
        if not _supports("timesteps"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" timestep schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(timesteps=timesteps, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    elif sigmas is not None:
        if not _supports("sigmas"):
            raise ValueError(
                f"The current scheduler class {scheduler.__class__}'s `set_timesteps` does not support custom"
                f" sigmas schedules. Please check whether you are using the correct scheduler."
            )
        scheduler.set_timesteps(sigmas=sigmas, device=device, **kwargs)
        num_inference_steps = len(scheduler.timesteps)
    else:
        # Default path: the caller's `num_inference_steps` is returned as-is.
        scheduler.set_timesteps(num_inference_steps, device=device, **kwargs)

    return scheduler.timesteps, num_inference_steps
136
+
137
+
138
@dataclass
class StableVideoDiffusionPipelineOutput(BaseOutput):
    r"""
    Output class for Stable Video Diffusion pipeline.

    Args:
        frames (`List[List[PIL.Image.Image]]`, `np.ndarray`, or `torch.Tensor`):
            List of denoised PIL images of length `batch_size` or numpy array or torch tensor of shape `(batch_size,
            num_frames, height, width, num_channels)`.
    """

    # Generated video frames, in the format selected via `output_type`.
    frames: Union[List[List[PIL.Image.Image]], np.ndarray, torch.Tensor]
150
+
151
+
152
+ class StableVideoDiffusionPipeline(DiffusionPipeline):
153
+ r"""
154
+ Pipeline to generate video from an input image using Stable Video Diffusion.
155
+
156
+ This model inherits from [`DiffusionPipeline`]. Check the superclass documentation for the generic methods
157
+ implemented for all pipelines (downloading, saving, running on a particular device, etc.).
158
+
159
+ Args:
160
+ vae ([`AutoencoderKLTemporalDecoder`]):
161
+ Variational Auto-Encoder (VAE) model to encode and decode images to and from latent representations.
162
+ image_encoder ([`~transformers.CLIPVisionModelWithProjection`]):
163
+ Frozen CLIP image-encoder
164
+ ([laion/CLIP-ViT-H-14-laion2B-s32B-b79K](https://huggingface.co/laion/CLIP-ViT-H-14-laion2B-s32B-b79K)).
165
+ unet ([`UNetSpatioTemporalConditionModel`]):
166
+ A `UNetSpatioTemporalConditionModel` to denoise the encoded image latents.
167
+ scheduler ([`EulerDiscreteScheduler`]):
168
+ A scheduler to be used in combination with `unet` to denoise the encoded image latents.
169
+ feature_extractor ([`~transformers.CLIPImageProcessor`]):
170
+ A `CLIPImageProcessor` to extract features from generated images.
171
+ """
172
+
173
+ model_cpu_offload_seq = "image_encoder->unet->vae"
174
+ _callback_tensor_inputs = ["latents"]
175
+
176
    def __init__(
        self,
        vae: AutoencoderKLTemporalDecoder,
        image_encoder: CLIPVisionModelWithProjection,
        unet: UNetSpatioTemporalConditionModel,
        scheduler: EulerDiscreteScheduler,
        feature_extractor: CLIPImageProcessor,
    ):
        """Register the sub-models and set up video pre/post-processing.

        Args:
            vae: Temporal-decoder VAE used to encode/decode image latents.
            image_encoder: Frozen CLIP vision tower producing image embeddings.
            unet: Spatio-temporal UNet that denoises the video latents.
            scheduler: Euler discrete scheduler driving the denoising loop.
            feature_extractor: CLIP image processor preparing `image_encoder` inputs.
        """
        super().__init__()

        self.register_modules(
            vae=vae,
            image_encoder=image_encoder,
            unet=unet,
            scheduler=scheduler,
            feature_extractor=feature_extractor,
        )
        # Spatial downsampling factor of the VAE (x2 per downsampling block);
        # falls back to 8 when no VAE module is registered.
        self.vae_scale_factor = 2 ** (len(self.vae.config.block_out_channels) - 1) if getattr(self, "vae", None) else 8
        self.video_processor = VideoProcessor(do_resize=True, vae_scale_factor=self.vae_scale_factor)
195
+
196
    def _encode_image(
        self,
        image: PipelineImageInput,
        device: Union[str, torch.device],
        num_videos_per_prompt: int,
        do_classifier_free_guidance: bool,
    ) -> torch.Tensor:
        """Encode the conditioning image into CLIP image embeddings.

        Returns a tensor of shape `[batch * num_videos_per_prompt, 1, embed_dim]`.
        When `do_classifier_free_guidance` is set, a zero (unconditional)
        embedding is prepended along the batch dimension so both guidance
        branches can run in a single forward pass.
        """
        dtype = next(self.image_encoder.parameters()).dtype

        if not isinstance(image, torch.Tensor):
            # PIL input: convert to a float tensor in [0, 1] first.
            image = self.video_processor.pil_to_numpy(image)
            image = self.video_processor.numpy_to_pt(image)

            # We normalize the image before resizing to match with the original implementation.
            # Then we unnormalize it after resizing.
            image = image * 2.0 - 1.0
            image = _resize_with_antialiasing(image, (224, 224))
            image = (image + 1.0) / 2.0

            # Normalize the image with for CLIP input; resize/crop/rescale are
            # disabled because resizing was already done above.
            image = self.feature_extractor(
                images=image,
                do_normalize=True,
                do_center_crop=False,
                do_resize=False,
                do_rescale=False,
                return_tensors="pt",
            ).pixel_values

        image = image.to(device=device, dtype=dtype)
        image_embeddings = self.image_encoder(image).image_embeds
        # Add a singleton token dimension: [batch, embed] -> [batch, 1, embed].
        image_embeddings = image_embeddings.unsqueeze(1)

        # duplicate image embeddings for each generation per prompt, using mps friendly method
        bs_embed, seq_len, _ = image_embeddings.shape
        image_embeddings = image_embeddings.repeat(1, num_videos_per_prompt, 1)
        image_embeddings = image_embeddings.view(bs_embed * num_videos_per_prompt, seq_len, -1)

        if do_classifier_free_guidance:
            negative_image_embeddings = torch.zeros_like(image_embeddings)

            # For classifier free guidance, we need to do two forward passes.
            # Here we concatenate the unconditional and text embeddings into a single batch
            # to avoid doing two forward passes
            image_embeddings = torch.cat([negative_image_embeddings, image_embeddings])

        return image_embeddings
243
+
244
+ def _encode_vae_image(
245
+ self,
246
+ image: torch.Tensor,
247
+ device: Union[str, torch.device],
248
+ num_videos_per_prompt: int,
249
+ do_classifier_free_guidance: bool,
250
+ ):
251
+ image = image.to(device=device)
252
+ image_latents = self.vae.encode(image).latent_dist.mode()
253
+
254
+ # duplicate image_latents for each generation per prompt, using mps friendly method
255
+ image_latents = image_latents.repeat(num_videos_per_prompt, 1, 1, 1)
256
+
257
+ if do_classifier_free_guidance:
258
+ negative_image_latents = torch.zeros_like(image_latents)
259
+
260
+ # For classifier free guidance, we need to do two forward passes.
261
+ # Here we concatenate the unconditional and text embeddings into a single batch
262
+ # to avoid doing two forward passes
263
+ image_latents = torch.cat([negative_image_latents, image_latents])
264
+
265
+ return image_latents
266
+
267
+ def _get_add_time_ids(
268
+ self,
269
+ fps: int,
270
+ motion_bucket_id: int,
271
+ noise_aug_strength: float,
272
+ dtype: torch.dtype,
273
+ batch_size: int,
274
+ num_videos_per_prompt: int,
275
+ do_classifier_free_guidance: bool,
276
+ ):
277
+ add_time_ids = [fps, motion_bucket_id, noise_aug_strength]
278
+
279
+ passed_add_embed_dim = self.unet.config.addition_time_embed_dim * len(add_time_ids)
280
+ expected_add_embed_dim = self.unet.add_embedding.linear_1.in_features
281
+
282
+ if expected_add_embed_dim != passed_add_embed_dim:
283
+ raise ValueError(
284
+ f"Model expects an added time embedding vector of length {expected_add_embed_dim}, but a vector of {passed_add_embed_dim} was created. The model has an incorrect config. Please check `unet.config.time_embedding_type` and `text_encoder_2.config.projection_dim`."
285
+ )
286
+
287
+ add_time_ids = torch.tensor([add_time_ids], dtype=dtype)
288
+ add_time_ids = add_time_ids.repeat(batch_size * num_videos_per_prompt, 1)
289
+
290
+ if do_classifier_free_guidance:
291
+ add_time_ids = torch.cat([add_time_ids, add_time_ids])
292
+
293
+ return add_time_ids
294
+
295
    def decode_latents(self, latents: torch.Tensor, num_frames: int, decode_chunk_size: int = 14):
        """Decode video latents to pixel frames with the temporal VAE decoder.

        Args:
            latents: Latents of shape `[batch, frames, channels, height, width]`.
            num_frames: Frames per video, used to un-flatten the batch at the end.
            decode_chunk_size: Frames decoded per VAE call; lower values reduce
                peak memory at some cost in temporal consistency.

        Returns:
            Float32 tensor of shape `[batch, channels, frames, height, width]`.
        """
        # [batch, frames, channels, height, width] -> [batch*frames, channels, height, width]
        latents = latents.flatten(0, 1)

        # Undo the scaling applied when the latents were encoded.
        latents = 1 / self.vae.config.scaling_factor * latents

        # `torch.compile` wraps the module; unwrap to inspect the real forward signature.
        forward_vae_fn = self.vae._orig_mod.forward if is_compiled_module(self.vae) else self.vae.forward
        accepts_num_frames = "num_frames" in set(inspect.signature(forward_vae_fn).parameters.keys())

        # decode decode_chunk_size frames at a time to avoid OOM
        frames = []
        for i in range(0, latents.shape[0], decode_chunk_size):
            num_frames_in = latents[i : i + decode_chunk_size].shape[0]
            decode_kwargs = {}
            if accepts_num_frames:
                # we only pass num_frames_in if it's expected
                decode_kwargs["num_frames"] = num_frames_in

            frame = self.vae.decode(latents[i : i + decode_chunk_size], **decode_kwargs).sample
            frames.append(frame)
        frames = torch.cat(frames, dim=0)

        # [batch*frames, channels, height, width] -> [batch, channels, frames, height, width]
        frames = frames.reshape(-1, num_frames, *frames.shape[1:]).permute(0, 2, 1, 3, 4)

        # we always cast to float32 as this does not cause significant overhead and is compatible with bfloat16
        frames = frames.float()
        return frames
323
+
324
+ def check_inputs(self, image, height, width):
325
+ if (
326
+ not isinstance(image, torch.Tensor)
327
+ and not isinstance(image, PIL.Image.Image)
328
+ and not isinstance(image, list)
329
+ ):
330
+ raise ValueError(
331
+ "`image` has to be of type `torch.Tensor` or `PIL.Image.Image` or `List[PIL.Image.Image]` but is"
332
+ f" {type(image)}"
333
+ )
334
+
335
+ if height % 8 != 0 or width % 8 != 0:
336
+ raise ValueError(f"`height` and `width` have to be divisible by 8 but are {height} and {width}.")
337
+
338
+ def prepare_latents(
339
+ self,
340
+ batch_size: int,
341
+ num_frames: int,
342
+ num_channels_latents: int,
343
+ height: int,
344
+ width: int,
345
+ dtype: torch.dtype,
346
+ device: Union[str, torch.device],
347
+ generator: torch.Generator,
348
+ latents: Optional[torch.Tensor] = None,
349
+ ):
350
+ shape = (
351
+ batch_size,
352
+ num_frames,
353
+ num_channels_latents // 2,
354
+ height // self.vae_scale_factor,
355
+ width // self.vae_scale_factor,
356
+ )
357
+ if isinstance(generator, list) and len(generator) != batch_size:
358
+ raise ValueError(
359
+ f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
360
+ f" size of {batch_size}. Make sure the batch size matches the length of the generators."
361
+ )
362
+
363
+ if latents is None:
364
+ latents = randn_tensor(shape, generator=generator, device=device, dtype=dtype)
365
+ else:
366
+ latents = latents.to(device)
367
+
368
+ # scale the initial noise by the standard deviation required by the scheduler
369
+ latents = latents * self.scheduler.init_noise_sigma
370
+ return latents
371
+
372
    @property
    def guidance_scale(self):
        # Per-frame guidance-scale tensor stored by `__call__`
        # (already broadcast to the latents' rank).
        return self._guidance_scale

    # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
    # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
    # corresponds to doing no classifier free guidance.
    @property
    def do_classifier_free_guidance(self):
        # `guidance_scale` may be a plain number or a per-frame tensor; CFG is
        # considered active when any frame's scale exceeds 1.
        if isinstance(self.guidance_scale, (int, float)):
            return self.guidance_scale > 1
        return self.guidance_scale.max() > 1

    @property
    def num_timesteps(self):
        # Number of scheduler timesteps recorded during the last `__call__`.
        return self._num_timesteps
388
+
389
+ @torch.no_grad()
390
+ @replace_example_docstring(EXAMPLE_DOC_STRING)
391
+ def __call__(
392
+ self,
393
+ image: Union[PIL.Image.Image, List[PIL.Image.Image], torch.Tensor],
394
+ height: int = 576,
395
+ width: int = 1024,
396
+ num_frames: Optional[int] = None,
397
+ num_inference_steps: int = 25,
398
+ sigmas: Optional[List[float]] = None,
399
+ min_guidance_scale: float = 1.0,
400
+ max_guidance_scale: float = 3.0,
401
+ fps: int = 7,
402
+ motion_bucket_id: int = 127,
403
+ noise_aug_strength: float = 0.02,
404
+ decode_chunk_size: Optional[int] = None,
405
+ num_videos_per_prompt: Optional[int] = 1,
406
+ generator: Optional[Union[torch.Generator, List[torch.Generator]]] = None,
407
+ latents: Optional[torch.Tensor] = None,
408
+ output_type: Optional[str] = "pil",
409
+ callback_on_step_end: Optional[Callable[[int, int, Dict], None]] = None,
410
+ callback_on_step_end_tensor_inputs: List[str] = ["latents"],
411
+ return_dict: bool = True,
412
+ ):
413
+ r"""
414
+ The call function to the pipeline for generation.
415
+
416
+ Args:
417
+ image (`PIL.Image.Image` or `List[PIL.Image.Image]` or `torch.Tensor`):
418
+ Image(s) to guide image generation. If you provide a tensor, the expected value range is between `[0,
419
+ 1]`.
420
+ height (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
421
+ The height in pixels of the generated image.
422
+ width (`int`, *optional*, defaults to `self.unet.config.sample_size * self.vae_scale_factor`):
423
+ The width in pixels of the generated image.
424
+ num_frames (`int`, *optional*):
425
+ The number of video frames to generate. Defaults to `self.unet.config.num_frames` (14 for
426
+ `stable-video-diffusion-img2vid` and to 25 for `stable-video-diffusion-img2vid-xt`).
427
+ num_inference_steps (`int`, *optional*, defaults to 25):
428
+ The number of denoising steps. More denoising steps usually lead to a higher quality video at the
429
+ expense of slower inference. This parameter is modulated by `strength`.
430
+ sigmas (`List[float]`, *optional*):
431
+ Custom sigmas to use for the denoising process with schedulers which support a `sigmas` argument in
432
+ their `set_timesteps` method. If not defined, the default behavior when `num_inference_steps` is passed
433
+ will be used.
434
+ min_guidance_scale (`float`, *optional*, defaults to 1.0):
435
+ The minimum guidance scale. Used for the classifier free guidance with first frame.
436
+ max_guidance_scale (`float`, *optional*, defaults to 3.0):
437
+ The maximum guidance scale. Used for the classifier free guidance with last frame.
438
+ fps (`int`, *optional*, defaults to 7):
439
+ Frames per second. The rate at which the generated images shall be exported to a video after
440
+ generation. Note that Stable Diffusion Video's UNet was micro-conditioned on fps-1 during training.
441
+ motion_bucket_id (`int`, *optional*, defaults to 127):
442
+ Used for conditioning the amount of motion for the generation. The higher the number the more motion
443
+ will be in the video.
444
+ noise_aug_strength (`float`, *optional*, defaults to 0.02):
445
+ The amount of noise added to the init image, the higher it is the less the video will look like the
446
+ init image. Increase it for more motion.
447
+ decode_chunk_size (`int`, *optional*):
448
+ The number of frames to decode at a time. Higher chunk size leads to better temporal consistency at the
449
+ expense of more memory usage. By default, the decoder decodes all frames at once for maximal quality.
450
+ For lower memory usage, reduce `decode_chunk_size`.
451
+ num_videos_per_prompt (`int`, *optional*, defaults to 1):
452
+ The number of videos to generate per prompt.
453
+ generator (`torch.Generator` or `List[torch.Generator]`, *optional*):
454
+ A [`torch.Generator`](https://pytorch.org/docs/stable/generated/torch.Generator.html) to make
455
+ generation deterministic.
456
+ latents (`torch.Tensor`, *optional*):
457
+ Pre-generated noisy latents sampled from a Gaussian distribution, to be used as inputs for video
458
+ generation. Can be used to tweak the same generation with different prompts. If not provided, a latents
459
+ tensor is generated by sampling using the supplied random `generator`.
460
+ output_type (`str`, *optional*, defaults to `"pil"`):
461
+ The output format of the generated image. Choose between `pil`, `np` or `pt`.
462
+ callback_on_step_end (`Callable`, *optional*):
463
+ A function that is called at the end of each denoising step during inference. The function is called
464
+ with the following arguments:
465
+ `callback_on_step_end(self: DiffusionPipeline, step: int, timestep: int, callback_kwargs: Dict)`.
466
+ `callback_kwargs` will include a list of all tensors as specified by
467
+ `callback_on_step_end_tensor_inputs`.
468
+ callback_on_step_end_tensor_inputs (`List`, *optional*):
469
+ The list of tensor inputs for the `callback_on_step_end` function. The tensors specified in the list
470
+ will be passed as `callback_kwargs` argument. You will only be able to include variables listed in the
471
+ `._callback_tensor_inputs` attribute of your pipeline class.
472
+ return_dict (`bool`, *optional*, defaults to `True`):
473
+ Whether or not to return a [`~pipelines.stable_diffusion.StableDiffusionPipelineOutput`] instead of a
474
+ plain tuple.
475
+
476
+ Examples:
477
+
478
+ Returns:
479
+ [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] or `tuple`:
480
+ If `return_dict` is `True`, [`~pipelines.stable_diffusion.StableVideoDiffusionPipelineOutput`] is
481
+ returned, otherwise a `tuple` of (`List[List[PIL.Image.Image]]` or `np.ndarray` or `torch.Tensor`) is
482
+ returned.
483
+ """
484
+ # 0. Default height and width to unet
485
+ height = height or self.unet.config.sample_size * self.vae_scale_factor
486
+ width = width or self.unet.config.sample_size * self.vae_scale_factor
487
+
488
+ num_frames = num_frames if num_frames is not None else self.unet.config.num_frames
489
+ decode_chunk_size = decode_chunk_size if decode_chunk_size is not None else num_frames
490
+
491
+ # 1. Check inputs. Raise error if not correct
492
+ self.check_inputs(image, height, width)
493
+
494
+ # 2. Define call parameters
495
+ if isinstance(image, PIL.Image.Image):
496
+ batch_size = 1
497
+ elif isinstance(image, list):
498
+ batch_size = len(image)
499
+ else:
500
+ batch_size = image.shape[0]
501
+ device = self._execution_device
502
+ # here `guidance_scale` is defined analog to the guidance weight `w` of equation (2)
503
+ # of the Imagen paper: https://huggingface.co/papers/2205.11487 . `guidance_scale = 1`
504
+ # corresponds to doing no classifier free guidance.
505
+ self._guidance_scale = max_guidance_scale
506
+
507
+ # 3. Encode input image
508
+ image_embeddings = self._encode_image(image, device, num_videos_per_prompt, self.do_classifier_free_guidance)
509
+
510
+ # NOTE: Stable Video Diffusion was conditioned on fps - 1, which is why it is reduced here.
511
+ # See: https://github.com/Stability-AI/generative-models/blob/ed0997173f98eaf8f4edf7ba5fe8f15c6b877fd3/scripts/sampling/simple_video_sample.py#L188
512
+ fps = fps - 1
513
+
514
+ # 4. Encode input image using VAE
515
+ image = self.video_processor.preprocess(image, height=height, width=width).to(device)
516
+ noise = randn_tensor(image.shape, generator=generator, device=device, dtype=image.dtype)
517
+ image = image + noise_aug_strength * noise
518
+
519
+ needs_upcasting = self.vae.dtype == torch.float16 and self.vae.config.force_upcast
520
+ if needs_upcasting:
521
+ self.vae.to(dtype=torch.float32)
522
+
523
+ image_latents = self._encode_vae_image(
524
+ image,
525
+ device=device,
526
+ num_videos_per_prompt=num_videos_per_prompt,
527
+ do_classifier_free_guidance=self.do_classifier_free_guidance,
528
+ )
529
+ image_latents = image_latents.to(image_embeddings.dtype)
530
+
531
+ # cast back to fp16 if needed
532
+ if needs_upcasting:
533
+ self.vae.to(dtype=torch.float16)
534
+
535
+ # Repeat the image latents for each frame so we can concatenate them with the noise
536
+ # image_latents [batch, channels, height, width] ->[batch, num_frames, channels, height, width]
537
+ image_latents = image_latents.unsqueeze(1).repeat(1, num_frames, 1, 1, 1)
538
+
539
+ # 5. Get Added Time IDs
540
+ added_time_ids = self._get_add_time_ids(
541
+ fps,
542
+ motion_bucket_id,
543
+ noise_aug_strength,
544
+ image_embeddings.dtype,
545
+ batch_size,
546
+ num_videos_per_prompt,
547
+ self.do_classifier_free_guidance,
548
+ )
549
+ added_time_ids = added_time_ids.to(device)
550
+
551
+ # 6. Prepare timesteps
552
+ timesteps, num_inference_steps = retrieve_timesteps(self.scheduler, num_inference_steps, device, None, sigmas)
553
+
554
+ # 7. Prepare latent variables
555
+ num_channels_latents = self.unet.config.in_channels
556
+ latents = self.prepare_latents(
557
+ batch_size * num_videos_per_prompt,
558
+ num_frames,
559
+ num_channels_latents,
560
+ height,
561
+ width,
562
+ image_embeddings.dtype,
563
+ device,
564
+ generator,
565
+ latents,
566
+ )
567
+
568
+ # 8. Prepare guidance scale
569
+ guidance_scale = torch.linspace(min_guidance_scale, max_guidance_scale, num_frames).unsqueeze(0)
570
+ guidance_scale = guidance_scale.to(device, latents.dtype)
571
+ guidance_scale = guidance_scale.repeat(batch_size * num_videos_per_prompt, 1)
572
+ guidance_scale = _append_dims(guidance_scale, latents.ndim)
573
+
574
+ self._guidance_scale = guidance_scale
575
+
576
+ # 9. Denoising loop
577
+ num_warmup_steps = len(timesteps) - num_inference_steps * self.scheduler.order
578
+ self._num_timesteps = len(timesteps)
579
+ with self.progress_bar(total=num_inference_steps) as progress_bar:
580
+ for i, t in enumerate(timesteps):
581
+ # expand the latents if we are doing classifier free guidance
582
+ latent_model_input = torch.cat([latents] * 2) if self.do_classifier_free_guidance else latents
583
+ latent_model_input = self.scheduler.scale_model_input(latent_model_input, t)
584
+
585
+ # Concatenate image_latents over channels dimension
586
+ latent_model_input = torch.cat([latent_model_input, image_latents], dim=2)
587
+
588
+ # predict the noise residual
589
+ noise_pred = self.unet(
590
+ latent_model_input,
591
+ t,
592
+ encoder_hidden_states=image_embeddings,
593
+ added_time_ids=added_time_ids,
594
+ return_dict=False,
595
+ )[0]
596
+
597
+ # perform guidance
598
+ if self.do_classifier_free_guidance:
599
+ noise_pred_uncond, noise_pred_cond = noise_pred.chunk(2)
600
+ noise_pred = noise_pred_uncond + self.guidance_scale * (noise_pred_cond - noise_pred_uncond)
601
+
602
+ # compute the previous noisy sample x_t -> x_t-1
603
+ latents = self.scheduler.step(noise_pred, t, latents).prev_sample
604
+
605
+ if callback_on_step_end is not None:
606
+ callback_kwargs = {}
607
+ for k in callback_on_step_end_tensor_inputs:
608
+ callback_kwargs[k] = locals()[k]
609
+ callback_outputs = callback_on_step_end(self, i, t, callback_kwargs)
610
+
611
+ latents = callback_outputs.pop("latents", latents)
612
+
613
+ if i == len(timesteps) - 1 or ((i + 1) > num_warmup_steps and (i + 1) % self.scheduler.order == 0):
614
+ progress_bar.update()
615
+
616
+ if XLA_AVAILABLE:
617
+ xm.mark_step()
618
+
619
+ if not output_type == "latent":
620
+ # cast back to fp16 if needed
621
+ if needs_upcasting:
622
+ self.vae.to(dtype=torch.float16)
623
+ frames = self.decode_latents(latents, num_frames, decode_chunk_size)
624
+ frames = self.video_processor.postprocess_video(video=frames, output_type=output_type)
625
+ else:
626
+ frames = latents
627
+
628
+ self.maybe_free_model_hooks()
629
+
630
+ if not return_dict:
631
+ return frames
632
+
633
+ return StableVideoDiffusionPipelineOutput(frames=frames)
634
+
635
+
636
# resizing utils
# TODO: clean up later
def _resize_with_antialiasing(input, size, interpolation="bicubic", align_corners=True):
    """Resize `input` to `size`, pre-blurring with a Gaussian to suppress aliasing when downscaling."""
    in_h, in_w = input.shape[-2:]
    scale_h = in_h / size[0]
    scale_w = in_w / size[1]

    # Sigma heuristic taken from skimage:
    # https://github.com/scikit-image/scikit-image/blob/v0.19.2/skimage/transform/_warps.py#L171
    sigma_h = max((scale_h - 1.0) / 2.0, 0.001)
    sigma_w = max((scale_w - 1.0) / 2.0, 0.001)

    # Kernel support of 2 sigmas per side: 3 sigmas gives good results but is slow,
    # Pillow uses 1 sigma over two passes (https://github.com/python-pillow/Pillow/blob/master/src/libImaging/Resample.c#L206);
    # 2 sigmas is the compromise used here.
    k_h = int(max(2.0 * 2 * sigma_h, 3))
    k_w = int(max(2.0 * 2 * sigma_w, 3))

    # Gaussian kernels must have odd length so they have a well-defined center.
    if k_h % 2 == 0:
        k_h += 1
    if k_w % 2 == 0:
        k_w += 1

    blurred = _gaussian_blur2d(input, (k_h, k_w), (sigma_h, sigma_w))

    return torch.nn.functional.interpolate(blurred, size=size, mode=interpolation, align_corners=align_corners)
665
+
666
+
667
+ def _compute_padding(kernel_size):
668
+ """Compute padding tuple."""
669
+ # 4 or 6 ints: (padding_left, padding_right,padding_top,padding_bottom)
670
+ # https://pytorch.org/docs/stable/nn.html#torch.nn.functional.pad
671
+ if len(kernel_size) < 2:
672
+ raise AssertionError(kernel_size)
673
+ computed = [k - 1 for k in kernel_size]
674
+
675
+ # for even kernels we need to do asymmetric padding :(
676
+ out_padding = 2 * len(kernel_size) * [0]
677
+
678
+ for i in range(len(kernel_size)):
679
+ computed_tmp = computed[-(i + 1)]
680
+
681
+ pad_front = computed_tmp // 2
682
+ pad_rear = computed_tmp - pad_front
683
+
684
+ out_padding[2 * i + 0] = pad_front
685
+ out_padding[2 * i + 1] = pad_rear
686
+
687
+ return out_padding
688
+
689
+
690
def _filter2d(input, kernel):
    """Convolve a (B, C, H, W) tensor with a (B or 1, kH, kW) kernel using reflect padding, preserving shape."""
    b, c, h, w = input.shape

    # Broadcast the kernel over channels and match the input's device/dtype.
    kern = kernel[:, None, ...].to(device=input.device, dtype=input.dtype)
    kern = kern.expand(-1, c, -1, -1)

    k_h, k_w = kern.shape[-2:]

    pad: List[int] = _compute_padding([k_h, k_w])
    padded = torch.nn.functional.pad(input, pad, mode="reflect")

    # Collapse batch/channel so a grouped conv applies each kernel slice independently.
    kern = kern.reshape(-1, 1, k_h, k_w)
    padded = padded.view(-1, kern.size(0), padded.size(-2), padded.size(-1))

    filtered = torch.nn.functional.conv2d(padded, kern, groups=kern.size(0), padding=0, stride=1)

    return filtered.view(b, c, h, w)
711
+
712
+
713
+ def _gaussian(window_size: int, sigma):
714
+ if isinstance(sigma, float):
715
+ sigma = torch.tensor([[sigma]])
716
+
717
+ batch_size = sigma.shape[0]
718
+
719
+ x = (torch.arange(window_size, device=sigma.device, dtype=sigma.dtype) - window_size // 2).expand(batch_size, -1)
720
+
721
+ if window_size % 2 == 0:
722
+ x = x + 0.5
723
+
724
+ gauss = torch.exp(-x.pow(2.0) / (2 * sigma.pow(2.0)))
725
+
726
+ return gauss / gauss.sum(-1, keepdim=True)
727
+
728
+
729
def _gaussian_blur2d(input, kernel_size, sigma):
    """Apply a separable Gaussian blur: filter along x, then along y."""
    if isinstance(sigma, tuple):
        sigma = torch.tensor([sigma], dtype=input.dtype)
    else:
        sigma = sigma.to(dtype=input.dtype)

    ky = int(kernel_size[0])
    kx = int(kernel_size[1])
    bs = sigma.shape[0]

    # One 1-D kernel per axis; sigma columns are ordered (y, x).
    kernel_x = _gaussian(kx, sigma[:, 1].view(bs, 1))
    kernel_y = _gaussian(ky, sigma[:, 0].view(bs, 1))

    # Apply as a row vector then a column vector (separable convolution).
    horizontally_blurred = _filter2d(input, kernel_x[..., None, :])
    return _filter2d(horizontally_blurred, kernel_y[..., None])
ctrl_world/src/models/unet_spatio_temporal_condition.py ADDED
@@ -0,0 +1,522 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from dataclasses import dataclass
2
+ from typing import Dict, Optional, Tuple, Union
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ from diffusers.configuration_utils import ConfigMixin, register_to_config
8
+ from diffusers.loaders import UNet2DConditionLoadersMixin
9
+ from diffusers.utils import BaseOutput, logging
10
+ from diffusers.models.attention_processor import CROSS_ATTENTION_PROCESSORS, AttentionProcessor, AttnProcessor
11
+ from diffusers.models.embeddings import TimestepEmbedding, Timesteps
12
+ from diffusers.models.modeling_utils import ModelMixin
13
+ from diffusers.models.unets.unet_3d_blocks import UNetMidBlockSpatioTemporal, get_down_block, get_up_block
14
+
15
+
16
+ logger = logging.get_logger(__name__) # pylint: disable=invalid-name
17
+
18
+
19
@dataclass
class UNetSpatioTemporalConditionOutput(BaseOutput):
    """
    The output of [`UNetSpatioTemporalConditionModel`].

    Args:
        sample (`torch.Tensor` of shape `(batch_size, num_frames, num_channels, height, width)`):
            The hidden states output conditioned on `encoder_hidden_states` input. Output of last layer of model.
    """

    # Predicted sample from the UNet forward pass; `None` only before assignment.
    sample: Optional[torch.Tensor] = None
30
+
31
+
32
+ class UNetSpatioTemporalConditionModel(ModelMixin, ConfigMixin, UNet2DConditionLoadersMixin):
33
+ r"""
34
+ A conditional Spatio-Temporal UNet model that takes a noisy video frames, conditional state, and a timestep and
35
+ returns a sample shaped output.
36
+
37
+ This model inherits from [`ModelMixin`]. Check the superclass documentation for its generic methods implemented
38
+ for all models (such as downloading or saving).
39
+
40
+ Parameters:
41
+ sample_size (`int` or `Tuple[int, int]`, *optional*, defaults to `None`):
42
+ Height and width of input/output sample.
43
+ in_channels (`int`, *optional*, defaults to 8): Number of channels in the input sample.
44
+ out_channels (`int`, *optional*, defaults to 4): Number of channels in the output.
45
+ down_block_types (`Tuple[str]`, *optional*, defaults to `("CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "CrossAttnDownBlockSpatioTemporal", "DownBlockSpatioTemporal")`):
46
+ The tuple of downsample blocks to use.
47
+ up_block_types (`Tuple[str]`, *optional*, defaults to `("UpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal", "CrossAttnUpBlockSpatioTemporal")`):
48
+ The tuple of upsample blocks to use.
49
+ block_out_channels (`Tuple[int]`, *optional*, defaults to `(320, 640, 1280, 1280)`):
50
+ The tuple of output channels for each block.
51
+ addition_time_embed_dim: (`int`, defaults to 256):
52
+ Dimension to encode the additional time ids.
53
+ projection_class_embeddings_input_dim (`int`, defaults to 768):
54
+ The dimension of the projection of encoded `added_time_ids`.
55
+ layers_per_block (`int`, *optional*, defaults to 2): The number of layers per block.
56
+ cross_attention_dim (`int` or `Tuple[int]`, *optional*, defaults to 1280):
57
+ The dimension of the cross attention features.
58
+ transformer_layers_per_block (`int`, `Tuple[int]`, or `Tuple[Tuple]` , *optional*, defaults to 1):
59
+ The number of transformer blocks of type [`~models.attention.BasicTransformerBlock`]. Only relevant for
60
+ [`~models.unets.unet_3d_blocks.CrossAttnDownBlockSpatioTemporal`],
61
+ [`~models.unets.unet_3d_blocks.CrossAttnUpBlockSpatioTemporal`],
62
+ [`~models.unets.unet_3d_blocks.UNetMidBlockSpatioTemporal`].
63
+ num_attention_heads (`int`, `Tuple[int]`, defaults to `(5, 10, 20, 20)`):
64
+ The number of attention heads.
65
+ dropout (`float`, *optional*, defaults to 0.0): The dropout probability to use.
66
+ """
67
+
68
+ _supports_gradient_checkpointing = True
69
+
70
    @register_to_config
    def __init__(
        self,
        sample_size: Optional[int] = None,
        in_channels: int = 8,
        out_channels: int = 4,
        down_block_types: Tuple[str] = (
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "CrossAttnDownBlockSpatioTemporal",
            "DownBlockSpatioTemporal",
        ),
        up_block_types: Tuple[str] = (
            "UpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
            "CrossAttnUpBlockSpatioTemporal",
        ),
        block_out_channels: Tuple[int] = (320, 640, 1280, 1280),
        addition_time_embed_dim: int = 256,
        projection_class_embeddings_input_dim: int = 768,
        layers_per_block: Union[int, Tuple[int]] = 2,
        cross_attention_dim: Union[int, Tuple[int]] = 1024,
        transformer_layers_per_block: Union[int, Tuple[int], Tuple[Tuple]] = 1,
        num_attention_heads: Union[int, Tuple[int]] = (5, 10, 20, 20),
        num_frames: int = 25,
    ):
        """Build the spatio-temporal UNet: input conv, time/added-time embeddings, down blocks, mid block, up blocks, output head.

        Per-block settings (`layers_per_block`, `cross_attention_dim`, `num_attention_heads`,
        `transformer_layers_per_block`) may be a single int applied to every block or a per-block sequence
        whose length must match `down_block_types`. Raises `ValueError` on any length mismatch.
        NOTE(review): `num_frames` is only recorded in the config here — presumably consumed elsewhere; confirm.
        """
        super().__init__()

        self.sample_size = sample_size

        # Check inputs
        if len(down_block_types) != len(up_block_types):
            raise ValueError(
                f"Must provide the same number of `down_block_types` as `up_block_types`. `down_block_types`: {down_block_types}. `up_block_types`: {up_block_types}."
            )

        if len(block_out_channels) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `block_out_channels` as `down_block_types`. `block_out_channels`: {block_out_channels}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(num_attention_heads, int) and len(num_attention_heads) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `num_attention_heads` as `down_block_types`. `num_attention_heads`: {num_attention_heads}. `down_block_types`: {down_block_types}."
            )

        if isinstance(cross_attention_dim, list) and len(cross_attention_dim) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `cross_attention_dim` as `down_block_types`. `cross_attention_dim`: {cross_attention_dim}. `down_block_types`: {down_block_types}."
            )

        if not isinstance(layers_per_block, int) and len(layers_per_block) != len(down_block_types):
            raise ValueError(
                f"Must provide the same number of `layers_per_block` as `down_block_types`. `layers_per_block`: {layers_per_block}. `down_block_types`: {down_block_types}."
            )

        # input
        self.conv_in = nn.Conv2d(
            in_channels,
            block_out_channels[0],
            kernel_size=3,
            padding=1,
        )

        # time
        # Timestep embedding width is 4x the first block's channel count.
        time_embed_dim = block_out_channels[0] * 4

        self.time_proj = Timesteps(block_out_channels[0], True, downscale_freq_shift=0)
        timestep_input_dim = block_out_channels[0]

        self.time_embedding = TimestepEmbedding(timestep_input_dim, time_embed_dim)

        # Sinusoidal projection + MLP for the extra conditioning ids (fps, motion bucket, noise aug).
        self.add_time_proj = Timesteps(addition_time_embed_dim, True, downscale_freq_shift=0)
        self.add_embedding = TimestepEmbedding(projection_class_embeddings_input_dim, time_embed_dim)

        self.down_blocks = nn.ModuleList([])
        self.up_blocks = nn.ModuleList([])

        # Normalize scalar per-block settings into one entry per down block.
        if isinstance(num_attention_heads, int):
            num_attention_heads = (num_attention_heads,) * len(down_block_types)

        if isinstance(cross_attention_dim, int):
            cross_attention_dim = (cross_attention_dim,) * len(down_block_types)

        if isinstance(layers_per_block, int):
            layers_per_block = [layers_per_block] * len(down_block_types)

        if isinstance(transformer_layers_per_block, int):
            transformer_layers_per_block = [transformer_layers_per_block] * len(down_block_types)

        blocks_time_embed_dim = time_embed_dim

        # down
        output_channel = block_out_channels[0]
        for i, down_block_type in enumerate(down_block_types):
            # Each block's input width is the previous block's output width.
            input_channel = output_channel
            output_channel = block_out_channels[i]
            is_final_block = i == len(block_out_channels) - 1

            down_block = get_down_block(
                down_block_type,
                num_layers=layers_per_block[i],
                transformer_layers_per_block=transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                temb_channels=blocks_time_embed_dim,
                add_downsample=not is_final_block,
                resnet_eps=1e-5,
                cross_attention_dim=cross_attention_dim[i],
                num_attention_heads=num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.down_blocks.append(down_block)

        # mid
        self.mid_block = UNetMidBlockSpatioTemporal(
            block_out_channels[-1],
            temb_channels=blocks_time_embed_dim,
            transformer_layers_per_block=transformer_layers_per_block[-1],
            cross_attention_dim=cross_attention_dim[-1],
            num_attention_heads=num_attention_heads[-1],
        )

        # count how many layers upsample the images
        # (used by `forward` to decide when an explicit upsample size must be forwarded)
        self.num_upsamplers = 0

        # up
        # The decoder mirrors the encoder, so all per-block settings are reversed.
        reversed_block_out_channels = list(reversed(block_out_channels))
        reversed_num_attention_heads = list(reversed(num_attention_heads))
        reversed_layers_per_block = list(reversed(layers_per_block))
        reversed_cross_attention_dim = list(reversed(cross_attention_dim))
        reversed_transformer_layers_per_block = list(reversed(transformer_layers_per_block))

        output_channel = reversed_block_out_channels[0]
        for i, up_block_type in enumerate(up_block_types):
            is_final_block = i == len(block_out_channels) - 1

            prev_output_channel = output_channel
            output_channel = reversed_block_out_channels[i]
            # Skip-connection channels come from the matching encoder level (clamped at the deepest level).
            input_channel = reversed_block_out_channels[min(i + 1, len(block_out_channels) - 1)]

            # add upsample block for all BUT final layer
            if not is_final_block:
                add_upsample = True
                self.num_upsamplers += 1
            else:
                add_upsample = False

            up_block = get_up_block(
                up_block_type,
                num_layers=reversed_layers_per_block[i] + 1,
                transformer_layers_per_block=reversed_transformer_layers_per_block[i],
                in_channels=input_channel,
                out_channels=output_channel,
                prev_output_channel=prev_output_channel,
                temb_channels=blocks_time_embed_dim,
                add_upsample=add_upsample,
                resnet_eps=1e-5,
                resolution_idx=i,
                cross_attention_dim=reversed_cross_attention_dim[i],
                num_attention_heads=reversed_num_attention_heads[i],
                resnet_act_fn="silu",
            )
            self.up_blocks.append(up_block)
            prev_output_channel = output_channel

        # out
        self.conv_norm_out = nn.GroupNorm(num_channels=block_out_channels[0], num_groups=32, eps=1e-5)
        self.conv_act = nn.SiLU()

        self.conv_out = nn.Conv2d(
            block_out_channels[0],
            out_channels,
            kernel_size=3,
            padding=1,
        )
247
+
248
+ @property
249
+ def attn_processors(self) -> Dict[str, AttentionProcessor]:
250
+ r"""
251
+ Returns:
252
+ `dict` of attention processors: A dictionary containing all attention processors used in the model with
253
+ indexed by its weight name.
254
+ """
255
+ # set recursively
256
+ processors = {}
257
+
258
+ def fn_recursive_add_processors(
259
+ name: str,
260
+ module: torch.nn.Module,
261
+ processors: Dict[str, AttentionProcessor],
262
+ ):
263
+ if hasattr(module, "get_processor"):
264
+ processors[f"{name}.processor"] = module.get_processor()
265
+
266
+ for sub_name, child in module.named_children():
267
+ fn_recursive_add_processors(f"{name}.{sub_name}", child, processors)
268
+
269
+ return processors
270
+
271
+ for name, module in self.named_children():
272
+ fn_recursive_add_processors(name, module, processors)
273
+
274
+ return processors
275
+
276
+ def set_attn_processor(self, processor: Union[AttentionProcessor, Dict[str, AttentionProcessor]]):
277
+ r"""
278
+ Sets the attention processor to use to compute attention.
279
+
280
+ Parameters:
281
+ processor (`dict` of `AttentionProcessor` or only `AttentionProcessor`):
282
+ The instantiated processor class or a dictionary of processor classes that will be set as the processor
283
+ for **all** `Attention` layers.
284
+
285
+ If `processor` is a dict, the key needs to define the path to the corresponding cross attention
286
+ processor. This is strongly recommended when setting trainable attention processors.
287
+
288
+ """
289
+ count = len(self.attn_processors.keys())
290
+
291
+ if isinstance(processor, dict) and len(processor) != count:
292
+ raise ValueError(
293
+ f"A dict of processors was passed, but the number of processors {len(processor)} does not match the"
294
+ f" number of attention layers: {count}. Please make sure to pass {count} processor classes."
295
+ )
296
+
297
+ def fn_recursive_attn_processor(name: str, module: torch.nn.Module, processor):
298
+ if hasattr(module, "set_processor"):
299
+ if not isinstance(processor, dict):
300
+ module.set_processor(processor)
301
+ else:
302
+ module.set_processor(processor.pop(f"{name}.processor"))
303
+
304
+ for sub_name, child in module.named_children():
305
+ fn_recursive_attn_processor(f"{name}.{sub_name}", child, processor)
306
+
307
+ for name, module in self.named_children():
308
+ fn_recursive_attn_processor(name, module, processor)
309
+
310
+ def set_default_attn_processor(self):
311
+ """
312
+ Disables custom attention processors and sets the default attention implementation.
313
+ """
314
+ if all(proc.__class__ in CROSS_ATTENTION_PROCESSORS for proc in self.attn_processors.values()):
315
+ processor = AttnProcessor()
316
+ else:
317
+ raise ValueError(
318
+ f"Cannot call `set_default_attn_processor` when attention processors are of type {next(iter(self.attn_processors.values()))}"
319
+ )
320
+
321
+ self.set_attn_processor(processor)
322
+
323
+ # Copied from diffusers.models.unets.unet_3d_condition.UNet3DConditionModel.enable_forward_chunking
324
+ def enable_forward_chunking(self, chunk_size: Optional[int] = None, dim: int = 0) -> None:
325
+ """
326
+ Sets the attention processor to use [feed forward
327
+ chunking](https://huggingface.co/blog/reformer#2-chunked-feed-forward-layers).
328
+
329
+ Parameters:
330
+ chunk_size (`int`, *optional*):
331
+ The chunk size of the feed-forward layers. If not specified, will run feed-forward layer individually
332
+ over each tensor of dim=`dim`.
333
+ dim (`int`, *optional*, defaults to `0`):
334
+ The dimension over which the feed-forward computation should be chunked. Choose between dim=0 (batch)
335
+ or dim=1 (sequence length).
336
+ """
337
+ if dim not in [0, 1]:
338
+ raise ValueError(f"Make sure to set `dim` to either 0 or 1, not {dim}")
339
+
340
+ # By default chunk size is 1
341
+ chunk_size = chunk_size or 1
342
+
343
+ def fn_recursive_feed_forward(module: torch.nn.Module, chunk_size: int, dim: int):
344
+ if hasattr(module, "set_chunk_feed_forward"):
345
+ module.set_chunk_feed_forward(chunk_size=chunk_size, dim=dim)
346
+
347
+ for child in module.children():
348
+ fn_recursive_feed_forward(child, chunk_size, dim)
349
+
350
+ for module in self.children():
351
+ fn_recursive_feed_forward(module, chunk_size, dim)
352
+
353
+ def forward(
354
+ self,
355
+ sample: torch.Tensor,
356
+ timestep: Union[torch.Tensor, float, int],
357
+ encoder_hidden_states: torch.Tensor,
358
+ added_time_ids: torch.Tensor,
359
+ return_dict: bool = True,
360
+ frame_level_cond=False,
361
+ ) -> Union[UNetSpatioTemporalConditionOutput, Tuple]:
362
+ r"""
363
+ The [`UNetSpatioTemporalConditionModel`] forward method.
364
+
365
+ Args:
366
+ sample (`torch.Tensor`):
367
+ The noisy input tensor with the following shape `(batch, num_frames, channel, height, width)`.
368
+ timestep (`torch.Tensor` or `float` or `int`): The number of timesteps to denoise an input.
369
+ encoder_hidden_states (`torch.Tensor`):
370
+ The encoder hidden states with shape `(batch, sequence_length, cross_attention_dim)`.
371
+ added_time_ids: (`torch.Tensor`):
372
+ The additional time ids with shape `(batch, num_additional_ids)`. These are encoded with sinusoidal
373
+ embeddings and added to the time embeddings.
374
+ return_dict (`bool`, *optional*, defaults to `True`):
375
+ Whether or not to return a [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] instead
376
+ of a plain tuple.
377
+ Returns:
378
+ [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] or `tuple`:
379
+ If `return_dict` is True, an [`~models.unet_slatio_temporal.UNetSpatioTemporalConditionOutput`] is
380
+ returned, otherwise a `tuple` is returned where the first element is the sample tensor.
381
+ """
382
+ # By default samples have to be AT least a multiple of the overall upsampling factor.
383
+ # The overall upsampling factor is equal to 2 ** (# num of upsampling layears).
384
+ # However, the upsampling interpolation output size can be forced to fit any upsampling size
385
+ # on the fly if necessary.
386
+ default_overall_up_factor = 2**self.num_upsamplers
387
+
388
+ # upsample size should be forwarded when sample is not a multiple of `default_overall_up_factor`
389
+ forward_upsample_size = False
390
+ upsample_size = None
391
+
392
+ if any(s % default_overall_up_factor != 0 for s in sample.shape[-2:]):
393
+ logger.info("Forward upsample size to force interpolation output size.")
394
+ forward_upsample_size = True
395
+
396
+ # 1. time
397
+ timesteps = timestep
398
+ if not torch.is_tensor(timesteps):
399
+ # TODO: this requires sync between CPU and GPU. So try to pass timesteps as tensors if you can
400
+ # This would be a good case for the `match` statement (Python 3.10+)
401
+ is_mps = sample.device.type == "mps"
402
+ is_npu = sample.device.type == "npu"
403
+ if isinstance(timestep, float):
404
+ dtype = torch.float32 if (is_mps or is_npu) else torch.float64
405
+ else:
406
+ dtype = torch.int32 if (is_mps or is_npu) else torch.int64
407
+ timesteps = torch.tensor([timesteps], dtype=dtype, device=sample.device)
408
+ elif len(timesteps.shape) == 0:
409
+ timesteps = timesteps[None].to(sample.device)
410
+
411
+ # broadcast to batch dimension in a way that's compatible with ONNX/Core ML
412
+ batch_size, num_frames = sample.shape[:2]
413
+ timesteps = timesteps.expand(batch_size)
414
+
415
+ t_emb = self.time_proj(timesteps)
416
+
417
+ # `Timesteps` does not contain any weights and will always return f32 tensors
418
+ # but time_embedding might actually be running in fp16. so we need to cast here.
419
+ # there might be better ways to encapsulate this.
420
+ t_emb = t_emb.to(dtype=sample.dtype)
421
+
422
+ emb = self.time_embedding(t_emb)
423
+
424
+ time_embeds = self.add_time_proj(added_time_ids.flatten())
425
+ time_embeds = time_embeds.reshape((batch_size, -1))
426
+ time_embeds = time_embeds.to(emb.dtype)
427
+ aug_emb = self.add_embedding(time_embeds)
428
+ emb = emb + aug_emb
429
+
430
+ # Flatten the batch and frames dimensions
431
+ # sample: [batch, frames, channels, height, width] -> [batch * frames, channels, height, width]
432
+ sample = sample.flatten(0, 1)
433
+ # Repeat the embeddings num_video_frames times
434
+ # emb: [batch, channels] -> [batch * frames, channels]
435
+ emb = emb.repeat_interleave(num_frames, dim=0, output_size=emb.shape[0] * num_frames)
436
+
437
+ # encoder_hidden_states: [batch, 1, channels] -> [batch * frames, 1, channels]
438
+ # encoder_hidden_states = encoder_hidden_states.repeat_interleave(
439
+ # num_frames, dim=0, output_size=encoder_hidden_states.shape[0] * num_frames
440
+ # )
441
+
442
+ ############################# newly added to support frame_level pose conditioning ########################################
443
+ # print('new one!!!!!!!!!')
444
+ if not frame_level_cond:
445
+ encoder_hidden_states = encoder_hidden_states.repeat_interleave(num_frames, dim=0)
446
+ else:
447
+ encoder_hidden_states = encoder_hidden_states.reshape(batch_size * num_frames, -1, encoder_hidden_states.shape[-1])
448
+ ############################################################################################################################
449
+
450
+ # 2. pre-process
451
+ sample = self.conv_in(sample)
452
+
453
+ image_only_indicator = torch.zeros(batch_size, num_frames, dtype=sample.dtype, device=sample.device)
454
+
455
+ down_block_res_samples = (sample,)
456
+ for downsample_block in self.down_blocks:
457
+ if hasattr(downsample_block, "has_cross_attention") and downsample_block.has_cross_attention:
458
+ sample, res_samples = downsample_block(
459
+ hidden_states=sample,
460
+ temb=emb,
461
+ encoder_hidden_states=encoder_hidden_states,
462
+ image_only_indicator=image_only_indicator,
463
+ )
464
+ else:
465
+ sample, res_samples = downsample_block(
466
+ hidden_states=sample,
467
+ temb=emb,
468
+ image_only_indicator=image_only_indicator,
469
+ )
470
+
471
+ down_block_res_samples += res_samples
472
+
473
+ # 4. mid
474
+ sample = self.mid_block(
475
+ hidden_states=sample,
476
+ temb=emb,
477
+ encoder_hidden_states=encoder_hidden_states,
478
+ image_only_indicator=image_only_indicator,
479
+ )
480
+
481
+ # 5. up
482
+ for i, upsample_block in enumerate(self.up_blocks):
483
+ is_final_block = i == len(self.up_blocks) - 1
484
+
485
+ res_samples = down_block_res_samples[-len(upsample_block.resnets) :]
486
+ down_block_res_samples = down_block_res_samples[: -len(upsample_block.resnets)]
487
+
488
+ # if we have not reached the final block and need to forward the
489
+ # upsample size, we do it here
490
+ if not is_final_block and forward_upsample_size:
491
+ upsample_size = down_block_res_samples[-1].shape[2:]
492
+
493
+ if hasattr(upsample_block, "has_cross_attention") and upsample_block.has_cross_attention:
494
+ sample = upsample_block(
495
+ hidden_states=sample,
496
+ temb=emb,
497
+ res_hidden_states_tuple=res_samples,
498
+ encoder_hidden_states=encoder_hidden_states,
499
+ upsample_size=upsample_size,
500
+ image_only_indicator=image_only_indicator,
501
+ )
502
+ else:
503
+ sample = upsample_block(
504
+ hidden_states=sample,
505
+ temb=emb,
506
+ res_hidden_states_tuple=res_samples,
507
+ upsample_size=upsample_size,
508
+ image_only_indicator=image_only_indicator,
509
+ )
510
+
511
+ # 6. post-process
512
+ sample = self.conv_norm_out(sample)
513
+ sample = self.conv_act(sample)
514
+ sample = self.conv_out(sample)
515
+
516
+ # 7. Reshape back to original shape
517
+ sample = sample.reshape(batch_size, num_frames, *sample.shape[1:])
518
+
519
+ if not return_dict:
520
+ return (sample,)
521
+
522
+ return UNetSpatioTemporalConditionOutput(sample=sample)
ctrl_world/src/world_model.py ADDED
@@ -0,0 +1,185 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from models.pipeline_stable_video_diffusion import StableVideoDiffusionPipeline
2
+ from models.unet_spatio_temporal_condition import UNetSpatioTemporalConditionModel
3
+ from models.pipeline_ctrl_world import CtrlWorldDiffusionPipeline
4
+ import torch
5
+ import torch.nn as nn
6
+
7
+ import json
8
+ import einops
9
+ import numpy as np
10
+ from huggingface_hub import snapshot_download
11
+ from transformers import AutoTokenizer, CLIPTextModelWithProjection
12
+
13
class Action_encoder2(nn.Module):
    """Encode a chunk of robot actions (optionally fused with a CLIP text
    embedding) into conditioning tokens for the video-diffusion UNet.

    Args:
        action_dim: dimensionality of a single action vector.
        action_num: number of actions in a chunk (history + future frames).
        hidden_size: width of the produced tokens (1024 here, matching twice
            the 512-d CLIP text projection added in ``forward``).
        text_cond: when True, a text embedding may be added to the tokens.
    """

    def __init__(self, action_dim, action_num, hidden_size, text_cond=True):
        super().__init__()
        self.action_dim = action_dim
        self.action_num = action_num
        self.hidden_size = hidden_size
        self.text_cond = text_cond

        input_dim = int(action_dim)
        self.action_encode = nn.Sequential(
            nn.Linear(input_dim, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
            nn.SiLU(),
            nn.Linear(1024, 1024),
        )
        # Kaiming initialization for every Linear layer. (BUGFIX: the
        # original only initialized the Linears at indices 0 and 2 and left
        # the final Linear at index 4 at the PyTorch default, contradicting
        # its own "kaiming initialization" comment.)
        for layer in self.action_encode:
            if isinstance(layer, nn.Linear):
                nn.init.kaiming_normal_(layer.weight, mode='fan_in', nonlinearity='relu')

    def forward(self, action, texts=None, text_tokinizer=None, text_encoder=None, frame_level_cond=True,):
        """Project actions to conditioning tokens.

        Args:
            action: (B, action_num, action_dim) tensor of actions.
            texts: optional list of B caption strings; only used when
                ``self.text_cond`` is True.
            text_tokinizer: tokenizer for ``texts`` (parameter name keeps the
                original spelling for backward compatibility).
            text_encoder: CLIP text model with a projection head exposing
                ``text_embeds``.
            frame_level_cond: when True, one token per action step is
                produced; when False, the chunk is flattened into a single
                token. NOTE(review): the flattened path feeds a (t*d)-wide
                vector into a Linear expecting ``action_dim`` inputs —
                callers must construct the module with a matching
                ``action_dim`` for that path; confirm before use.

        Returns:
            (B, T, hidden) tokens if ``frame_level_cond`` else (B, 1, hidden).
        """
        if not frame_level_cond:
            # Collapse the whole chunk into one flattened token.
            action = einops.rearrange(action, 'b t d -> b 1 (t d)')
        action = self.action_encode(action)

        if texts is not None and self.text_cond:
            # Add the (frozen) CLIP text embedding on top of the action tokens.
            with torch.no_grad():
                inputs = text_tokinizer(texts, padding='max_length', return_tensors="pt", truncation=True).to(text_encoder.device)
                outputs = text_encoder(**inputs)
                hidden_text = outputs.text_embeds  # (B, 512)
                # Tile the 512-d embedding twice to match the 1024-d tokens;
                # broadcasting adds it to every time step.
                hidden_text = einops.repeat(hidden_text, 'b c -> b 1 (n c)', n=2)  # (B, 1, 1024)

            action = action + hidden_text  # (B, T, hidden_size)
        return action  # (B, 1, hidden_size) or (B, T, hidden_size) if frame_level_cond
50
+
51
+
52
class CrtlWorld(nn.Module):
    """Ctrl-World world model: Stable Video Diffusion whose UNet is replaced
    by an action-conditioned spatio-temporal UNet, plus a frozen CLIP text
    encoder and an action encoder that produces per-frame conditioning.

    The class name keeps the original "CrtlWorld" spelling because external
    configs reference the class by this exact name.
    """

    def __init__(self, config: dict):
        """Build the model from ``config``.

        Required keys: ``svd_model_path``, ``clip_model_path``,
        ``action_dim``, ``num_history``, ``num_frames``, ``text_cond``,
        ``data_stat_path``, plus the sampling options read by the methods
        below (``width``, ``height``, ``num_inference_steps``,
        ``decode_chunk_size``, ``guidance_scale``, ``fps``,
        ``motion_bucket_id``).
        """
        super().__init__()

        self.config = config
        # Download / locate the pretrained Stable Video Diffusion weights.
        model_local_path = snapshot_download(
            repo_id=config["svd_model_path"],  # e.g. "stabilityai/stable-video-diffusion-img2vid"
            repo_type="model",
        )

        # Load the full SVD pipeline from the downloaded snapshot.
        self.pipeline = StableVideoDiffusionPipeline.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )

        # Swap the stock UNet for the action-conditioned variant, copying
        # over every weight that still matches (strict=False tolerates the
        # newly added conditioning parameters).
        unet = UNetSpatioTemporalConditionModel()
        unet.load_state_dict(self.pipeline.unet.state_dict(), strict=False)
        self.pipeline.unet = unet

        self.unet = self.pipeline.unet
        self.vae = self.pipeline.vae
        self.image_encoder = self.pipeline.image_encoder
        self.scheduler = self.pipeline.scheduler

        # Freeze VAE and image encoder; only the UNet trains, with gradient
        # checkpointing enabled to save memory.
        self.vae.requires_grad_(False)
        self.image_encoder.requires_grad_(False)
        self.unet.requires_grad_(True)
        self.unet.enable_gradient_checkpointing()

        # SVD is an img2video model, so load a separate (frozen) CLIP text
        # encoder for optional language conditioning.
        model_local_path = snapshot_download(
            repo_id=config["clip_model_path"],  # e.g. "openai/clip-vit-base-patch32"
            repo_type="model",
        )

        self.text_encoder = CLIPTextModelWithProjection.from_pretrained(
            model_local_path,
            torch_dtype="auto",
        )
        self.tokenizer = AutoTokenizer.from_pretrained(model_local_path, use_fast=False)
        self.text_encoder.requires_grad_(False)

        # Action projector: encodes (history + future) action chunks.
        self.action_encoder = Action_encoder2(
            action_dim=config["action_dim"],
            action_num=int(config["num_history"] + config["num_frames"]),
            hidden_size=1024,
            text_cond=config["text_cond"],
        )

        # Per-dimension 1st/99th-percentile state statistics for
        # normalize_bound. BUGFIX: the original wrote
        # f"{config["data_stat_path"]}", whose nested double quotes are a
        # SyntaxError before Python 3.12; plain indexing is equivalent.
        # NOTE(review): "data_stat_path" is required here but absent from the
        # shipped config.json class_args — verify deployment configs.
        with open(config["data_stat_path"], "r") as f:
            data_stat = json.load(f)
        self.state_p01 = np.array(data_stat['state_01'])[None, :]
        self.state_p99 = np.array(data_stat['state_99'])[None, :]

    def normalize_bound(
        self,
        data: np.ndarray,
        clip_min: float = -1,
        clip_max: float = 1,
        eps: float = 1e-8,
    ) -> np.ndarray:
        """Map ``data`` into [clip_min, clip_max] using the stored 1st/99th
        percentile bounds; ``eps`` guards against a zero-width range."""
        ndata = 2 * (data - self.state_p01) / (self.state_p99 - self.state_p01 + eps) - 1
        return np.clip(ndata, clip_min, clip_max)

    def decode(self, latents: torch.Tensor):
        """Decode scaled VAE latents back to pixel-space videos.

        Args:
            latents: (B, F, C, H, W) latents in the VAE's scaled space.

        Returns:
            (B, F, ...) float video tensor clamped to [0, 1], detached and
            moved to CPU.
        """
        bsz, frame_num = latents.shape[:2]
        x = latents.flatten(0, 1)

        decoded = []
        chunk_size = self.config["decode_chunk_size"]
        for i in range(0, x.shape[0], chunk_size):
            # Undo the VAE scaling before decoding each chunk.
            chunk = x[i:i + chunk_size] / self.pipeline.vae.config.scaling_factor
            decode_kwargs = {"num_frames": chunk.shape[0]}
            out = self.pipeline.vae.decode(chunk, **decode_kwargs).sample
            decoded.append(out)

        videos = torch.cat(decoded, dim=0)
        videos = videos.reshape(bsz, frame_num, *videos.shape[1:])
        videos = (videos / 2.0 + 0.5).clamp(0, 1)  # [-1, 1] -> [0, 1]
        videos = videos.detach().float().cpu()
        # BUGFIX: the original computed `videos` but never returned it,
        # so callers received None.
        return videos

    def encode(self, img: torch.Tensor):
        """Encode a single image (values assumed in [0, 1]) to a scaled VAE
        latent; returns a detached tensor with a leading batch dim of 1."""
        x = img.unsqueeze(0)
        x = x * 2 - 1  # [0,1] -> [-1,1]

        vae = self.pipeline.vae
        with torch.no_grad():
            latent = vae.encode(x).latent_dist.sample()
            latent = latent * vae.config.scaling_factor

        return latent.detach()

    def action_text_encode(self, action: torch.Tensor, text):
        """Encode an action chunk (plus optional caption) into conditioning
        tokens via the action encoder; gradients are never tracked."""
        action_tensor = action.unsqueeze(0)

        with torch.no_grad():
            if text is not None and self.config["text_cond"]:
                text_token = self.action_encoder(action_tensor, [text], self.tokenizer, self.text_encoder)
            else:
                text_token = self.action_encoder(action_tensor)

        return text_token.detach()

    def get_latent_views(self, frames, current_latent, text_token):
        """Roll the world model forward one chunk.

        Args:
            frames: list of history latent frames, concatenated into a
                (1, num_history, C, H, W) conditioning tensor.
            current_latent: latent of the current observation.
            text_token: action/text conditioning from ``action_text_encode``.

        Returns:
            Predicted future latents from the Ctrl-World diffusion pipeline.
        """
        his_cond = torch.cat(frames, dim=0).unsqueeze(0)  # (1, num_history, 4, stacked_H, W)

        # Run CtrlWorldDiffusionPipeline.__call__ against the SVD pipeline
        # instance (unbound-call style keeps the SVD-loaded components).
        with torch.no_grad():
            _, latents = CtrlWorldDiffusionPipeline.__call__(
                self.pipeline,
                image=current_latent,
                text=text_token,
                width=self.config["width"],
                height=int(self.config["height"] * 3),  # 3 camera views stacked vertically
                num_frames=self.config["num_frames"],
                history=his_cond,
                num_inference_steps=self.config["num_inference_steps"],
                decode_chunk_size=self.config["decode_chunk_size"],
                max_guidance_scale=self.config["guidance_scale"],
                fps=self.config["fps"],
                motion_bucket_id=self.config["motion_bucket_id"],
                mask=None,
                output_type="latent",
                return_dict=False,
                frame_level_cond=True,
            )

        return latents