Spaces:
Runtime error
Runtime error
| from typing import Optional, Tuple | |
| import torch | |
| from diffusers.models.embeddings import get_3d_rotary_pos_embed | |
| from diffusers.pipelines.cogvideo.pipeline_cogvideox import get_resize_crop_region_for_grid | |
def prepare_rotary_positional_embeddings(
    height: int,
    width: int,
    num_frames: int,
    vae_scale_factor_spatial: int = 8,
    patch_size: int = 2,
    patch_size_t: Optional[int] = None,
    attention_head_dim: int = 64,
    device: Optional[torch.device] = None,
    base_height: int = 480,
    base_width: int = 720,
) -> Tuple[torch.Tensor, torch.Tensor]:
    """Build the 3D rotary positional embedding tables for a CogVideoX transformer.

    Spatial sizes are converted to patch-grid sizes by floor-dividing by
    ``vae_scale_factor_spatial * patch_size``. The reference size
    (``base_height`` x ``base_width``) is converted the same way.

    Two variants are supported, selected by ``patch_size_t``:

    * ``patch_size_t is None`` (CogVideoX 1.0): the grid is mapped onto the
      reference grid with ``get_resize_crop_region_for_grid``. ``num_frames``
      is used directly as the temporal size.
    * otherwise (CogVideoX 1.5): no crop coordinates are used. The grid type
      is ``"slice"`` and ``max_size`` is the reference grid. The temporal size
      is ``ceil(num_frames / patch_size_t)``.

    Args:
        height: Height before VAE downsampling. It is divided by
            ``vae_scale_factor_spatial * patch_size`` to get the grid height.
        width: Width before VAE downsampling, converted the same way.
        num_frames: Temporal length fed to the embedding. This is presumably
            the latent frame count, since no VAE temporal factor is applied
            here (TODO confirm with callers).
        vae_scale_factor_spatial: Spatial downsampling factor of the VAE.
        patch_size: Spatial patch size of the transformer.
        patch_size_t: Temporal patch size. ``None`` selects the CogVideoX 1.0
            behaviour.
        attention_head_dim: Embedding dimension passed as ``embed_dim``.
        device: Device the returned tensors are moved to via ``Tensor.to``.
        base_height: Reference height, before VAE downsampling, used to
            build the reference grid.
        base_width: Reference width, before VAE downsampling, used to build
            the reference grid.

    Returns:
        ``(freqs_cos, freqs_sin)`` as produced by ``get_3d_rotary_pos_embed``,
        moved to ``device``.
    """
    # Number of input pixels covered by one patch along each spatial axis.
    spatial_stride = vae_scale_factor_spatial * patch_size

    grid_height = height // spatial_stride
    grid_width = width // spatial_stride
    base_size_width = base_width // spatial_stride
    base_size_height = base_height // spatial_stride

    if patch_size_t is None:
        # CogVideoX 1.0: position the current grid inside the reference grid.
        # Note the (width, height) argument order expected by the helper.
        grid_crops_coords = get_resize_crop_region_for_grid(
            (grid_height, grid_width), base_size_width, base_size_height
        )
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=grid_crops_coords,
            grid_size=(grid_height, grid_width),
            temporal_size=num_frames,
        )
    else:
        # CogVideoX 1.5: frames are grouped into temporal patches, so use the
        # ceil-divided frame count as the temporal size.
        base_num_frames = (num_frames + patch_size_t - 1) // patch_size_t
        freqs_cos, freqs_sin = get_3d_rotary_pos_embed(
            embed_dim=attention_head_dim,
            crops_coords=None,
            grid_size=(grid_height, grid_width),
            temporal_size=base_num_frames,
            grid_type="slice",
            max_size=(base_size_height, base_size_width),
        )

    return freqs_cos.to(device=device), freqs_sin.to(device=device)