Update animatediff/pipelines/pipeline_animation.py
Browse files
animatediff/pipelines/pipeline_animation.py
CHANGED
|
@@ -8,6 +8,8 @@ import numpy as np
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
|
|
|
|
|
|
| 11 |
from diffusers.utils import is_accelerate_available
|
| 12 |
from packaging import version
|
| 13 |
from transformers import CLIPTextModel, CLIPTokenizer
|
|
@@ -28,7 +30,7 @@ from diffusers.utils import deprecate, logging, BaseOutput
|
|
| 28 |
from einops import rearrange
|
| 29 |
|
| 30 |
from ..models.unet import UNet3DConditionModel
|
| 31 |
-
|
| 32 |
|
| 33 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 34 |
|
|
@@ -283,8 +285,29 @@ class AnimationPipeline(DiffusionPipeline):
|
|
| 283 |
f" {type(callback_steps)}."
|
| 284 |
)
|
| 285 |
|
| 286 |
-
def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
|
|
|
| 287 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 288 |
if isinstance(generator, list) and len(generator) != batch_size:
|
| 289 |
raise ValueError(
|
| 290 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
|
@@ -296,6 +319,7 @@ class AnimationPipeline(DiffusionPipeline):
|
|
| 296 |
if isinstance(generator, list):
|
| 297 |
shape = shape
|
| 298 |
# shape = (1,) + shape[1:]
|
|
|
|
| 299 |
latents = [
|
| 300 |
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
| 301 |
for i in range(batch_size)
|
|
@@ -303,19 +327,29 @@ class AnimationPipeline(DiffusionPipeline):
|
|
| 303 |
latents = torch.cat(latents, dim=0).to(device)
|
| 304 |
else:
|
| 305 |
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 306 |
else:
|
| 307 |
if latents.shape != shape:
|
| 308 |
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 309 |
latents = latents.to(device)
|
| 310 |
|
| 311 |
# scale the initial noise by the standard deviation required by the scheduler
|
| 312 |
-
latents = latents * self.scheduler.init_noise_sigma
|
|
|
|
|
|
|
| 313 |
return latents
|
| 314 |
|
| 315 |
@torch.no_grad()
|
| 316 |
def __call__(
|
| 317 |
self,
|
| 318 |
prompt: Union[str, List[str]],
|
|
|
|
| 319 |
video_length: Optional[int],
|
| 320 |
height: Optional[int] = None,
|
| 321 |
width: Optional[int] = None,
|
|
@@ -368,6 +402,7 @@ class AnimationPipeline(DiffusionPipeline):
|
|
| 368 |
# Prepare latent variables
|
| 369 |
num_channels_latents = self.unet.in_channels
|
| 370 |
latents = self.prepare_latents(
|
|
|
|
| 371 |
batch_size * num_videos_per_prompt,
|
| 372 |
num_channels_latents,
|
| 373 |
video_length,
|
|
|
|
| 8 |
import torch
|
| 9 |
from tqdm import tqdm
|
| 10 |
|
| 11 |
+
import PIL
|
| 12 |
+
|
| 13 |
from diffusers.utils import is_accelerate_available
|
| 14 |
from packaging import version
|
| 15 |
from transformers import CLIPTextModel, CLIPTokenizer
|
|
|
|
| 30 |
from einops import rearrange
|
| 31 |
|
| 32 |
from ..models.unet import UNet3DConditionModel
|
| 33 |
+
from ..utils.util import preprocess_image
|
| 34 |
|
| 35 |
logger = logging.get_logger(__name__) # pylint: disable=invalid-name
|
| 36 |
|
|
|
|
| 285 |
f" {type(callback_steps)}."
|
| 286 |
)
|
| 287 |
|
| 288 |
+
#def prepare_latents(self, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
| 289 |
+
def prepare_latents(self, init_image, batch_size, num_channels_latents, video_length, height, width, dtype, device, generator, latents=None):
|
| 290 |
shape = (batch_size, num_channels_latents, video_length, height // self.vae_scale_factor, width // self.vae_scale_factor)
|
| 291 |
+
|
| 292 |
+
if init_image is not None:
|
| 293 |
+
image = PIL.Image.open(init_image)
|
| 294 |
+
image = preprocess_image(image)
|
| 295 |
+
if not isinstance(image, (torch.Tensor, PIL.Image.Image, list)):
|
| 296 |
+
raise ValueError(
|
| 297 |
+
f"`image` has to be of type `torch.Tensor`, `PIL.Image.Image` or list but is {type(image)}"
|
| 298 |
+
)
|
| 299 |
+
image = image.to(device=device, dtype=dtype)
|
| 300 |
+
if isinstance(generator, list):
|
| 301 |
+
init_latents = [
|
| 302 |
+
self.vae.encode(image[i : i + 1]).latent_dist.sample(generator[i]) for i in range(batch_size)
|
| 303 |
+
]
|
| 304 |
+
init_latents = torch.cat(init_latents, dim=0)
|
| 305 |
+
else:
|
| 306 |
+
init_latents = self.vae.encode(image).latent_dist.sample(generator)
|
| 307 |
+
else:
|
| 308 |
+
init_latents = None
|
| 309 |
+
|
| 310 |
+
|
| 311 |
if isinstance(generator, list) and len(generator) != batch_size:
|
| 312 |
raise ValueError(
|
| 313 |
f"You have passed a list of generators of length {len(generator)}, but requested an effective batch"
|
|
|
|
| 319 |
if isinstance(generator, list):
|
| 320 |
shape = shape
|
| 321 |
# shape = (1,) + shape[1:]
|
| 322 |
+
# ignore init latents for batch model
|
| 323 |
latents = [
|
| 324 |
torch.randn(shape, generator=generator[i], device=rand_device, dtype=dtype)
|
| 325 |
for i in range(batch_size)
|
|
|
|
| 327 |
latents = torch.cat(latents, dim=0).to(device)
|
| 328 |
else:
|
| 329 |
latents = torch.randn(shape, generator=generator, device=rand_device, dtype=dtype).to(device)
|
| 330 |
+
|
| 331 |
+
if init_latents is not None:
|
| 332 |
+
for i in range(video_length):
|
| 333 |
+
# Dividing by 30 empirically yields stable results, though the reason is not yet understood
|
| 334 |
+
# gradually reduce init alpha along the video frames (loosening the restriction)
|
| 335 |
+
init_alpha = (video_length - float(i)) / video_length / 30
|
| 336 |
+
latents[:, :, i, :, :] = init_latents * init_alpha + latents[:, :, i, :, :] * (1 - init_alpha)
|
| 337 |
else:
|
| 338 |
if latents.shape != shape:
|
| 339 |
raise ValueError(f"Unexpected latents shape, got {latents.shape}, expected {shape}")
|
| 340 |
latents = latents.to(device)
|
| 341 |
|
| 342 |
# scale the initial noise by the standard deviation required by the scheduler
|
| 343 |
+
#latents = latents * self.scheduler.init_noise_sigma
|
| 344 |
+
if init_latents is None:
|
| 345 |
+
latents = latents * self.scheduler.init_noise_sigma
|
| 346 |
return latents
|
| 347 |
|
| 348 |
@torch.no_grad()
|
| 349 |
def __call__(
|
| 350 |
self,
|
| 351 |
prompt: Union[str, List[str]],
|
| 352 |
+
init_image: str = None,
|
| 353 |
video_length: Optional[int],
|
| 354 |
height: Optional[int] = None,
|
| 355 |
width: Optional[int] = None,
|
|
|
|
| 402 |
# Prepare latent variables
|
| 403 |
num_channels_latents = self.unet.in_channels
|
| 404 |
latents = self.prepare_latents(
|
| 405 |
+
init_image,
|
| 406 |
batch_size * num_videos_per_prompt,
|
| 407 |
num_channels_latents,
|
| 408 |
video_length,
|