|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
import torch |
|
|
from torch.utils.data import DataLoader |
|
|
|
|
|
from cosmos_predict1.tokenizer.training.datasets.mock_dataset import CombinedDictDataset, LambdaDataset |
|
|
from cosmos_predict1.tokenizer.training.datasets.utils import VIDEO_KEY, VIDEO_VAL_CROP_SIZE_INFO, get_crop_size_info |
|
|
from cosmos_predict1.utils import log |
|
|
from cosmos_predict1.utils.lazy_config import LazyCall as L |
|
|
from cosmos_predict1.utils.lazy_config import LazyDict |
|
|
|
|
|
# Aspect-ratio key for images. Not referenced by any code visible in this
# part of the file.
# NOTE(review): presumably a "width,height" ratio string like the video key
# below -- confirm against the crop-size tables in datasets/utils.
_IMAGE_ASPECT_RATIO = "1,1" |
|
|
# Aspect-ratio key used by get_video_dataset to pick one entry out of the
# crop-size mapping; that entry is indexed as [0] -> width, [1] -> height.
_VIDEO_ASPECT_RATIO = "16,9" |
|
|
|
|
|
|
|
|
def get_video_dataset(
    is_train: bool,
    resolution: str,
    crop_height: int,
    num_video_frames: int,
):
    """Build a mock video dataset whose samples are random tensors.

    Args:
        is_train: Selects the training or validation path. Training always
            derives crop sizes from ``crop_height``; validation may fall back
            to the preset table (see ``crop_height``).
        resolution: Key into ``VIDEO_VAL_CROP_SIZE_INFO``. Only consulted on
            the validation path when ``crop_height`` is ``None``.
        crop_height: Value handed to ``get_crop_size_info``. On the validation
            path ``None`` is accepted and means "use the preset crop sizes
            for ``resolution``".
        num_video_frames: Size of the second dimension of every generated
            sample.

    Returns:
        A ``CombinedDictDataset`` with a single ``VIDEO_KEY`` entry wrapping a
        ``LambdaDataset`` that produces tensors of shape
        ``(3, num_video_frames, h, w)`` with values in ``[-1, 1)``.
    """
    # Only validation without an explicit crop height uses the preset table.
    use_val_preset = (not is_train) and crop_height is None
    if use_val_preset:
        crop_sizes = VIDEO_VAL_CROP_SIZE_INFO[resolution]
    else:
        crop_sizes = get_crop_size_info(crop_height)

    if is_train:
        log.info(
            f"[video] training num_frames={num_video_frames}, crop_height={crop_height} and crop_sizes: {crop_sizes}."
        )
    else:
        log.info(f"[video] validation num_frames={num_video_frames}, crop_sizes: {crop_sizes}")

    # The entry for the video aspect ratio is laid out as (width, height).
    video_crop = crop_sizes[_VIDEO_ASPECT_RATIO]
    height = video_crop[1]
    width = video_crop[0]

    def _random_clip():
        # torch.rand samples from [0, 1); rescale to [-1, 1).
        return torch.rand(3, num_video_frames, height, width) * 2 - 1

    datasets = {VIDEO_KEY: LambdaDataset(_random_clip)}
    return CombinedDictDataset(**datasets)
|
|
|
|
|
|
|
|
def get_mock_video_dataloader(
    batch_size: int, is_train: bool = True, num_video_frames: int = 9, resolution: str = "720", crop_height: int = 128
) -> LazyDict:
    """Describe a ``DataLoader`` over the mock video dataset as a lazy config.

    Args:
        batch_size: The batch size passed to the ``DataLoader``.
        is_train: Forwarded to ``get_video_dataset``. Defaults to True.
        num_video_frames: The number of video frames per sample. Defaults to 9.
        resolution: The resolution key. Defaults to "720". Any value that is
            not a key of ``VIDEO_VAL_CROP_SIZE_INFO`` is replaced by "720".
        crop_height: Forwarded to ``get_video_dataset``. Defaults to 128.

    Returns:
        LazyDict: A LazyDict object specifying the video dataloader
        (``shuffle=False``, ``num_workers=8``).
    """
    # Fall back to the default resolution when the requested one is unknown.
    effective_resolution = resolution if resolution in VIDEO_VAL_CROP_SIZE_INFO else "720"

    dataset_cfg = L(get_video_dataset)(
        is_train=is_train,
        resolution=effective_resolution,
        crop_height=crop_height,
        num_video_frames=num_video_frames,
    )
    return L(DataLoader)(
        dataset=dataset_cfg,
        batch_size=batch_size,
        shuffle=False,
        num_workers=8,
    )
|
|
|