# SPDX-FileCopyrightText: Copyright (c) 2025 NVIDIA CORPORATION & AFFILIATES. All rights reserved.
# SPDX-License-Identifier: Apache-2.0
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
"""Implementations of dataset settings and augmentations for tokenization.

Run this command to interactively debug:
python3 -m cosmos_predict1.tokenizer.training.datasets.dataset_provider
"""
from cosmos_predict1.tokenizer.training.datasets.augmentation_provider import (
    video_train_augmentations,
    video_val_augmentations,
)
from cosmos_predict1.tokenizer.training.datasets.utils import categorize_aspect_and_store
from cosmos_predict1.tokenizer.training.datasets.video_dataset import Dataset
from cosmos_predict1.utils.lazy_config import instantiate
# Maps a dataset name (lower-cased in dataset_entry) to the glob pattern
# locating that dataset's video files on disk.
_VIDEO_PATTERN_DICT = {
    "hdvila_video": "datasets/hdvila/videos/*.mp4",
}
def apply_augmentations(data_dict, augmentations_dict):
    """Instantiate each LazyCall augmentation and apply it to the sample.

    Args:
        data_dict: sample dictionary to transform.
        augmentations_dict: mapping of augmentation name -> LazyCall spec.

    Returns:
        The transformed sample dictionary.
    """
    # Names are only labels here; apply the specs in insertion order.
    for lazy_spec in augmentations_dict.values():
        data_dict = instantiate(lazy_spec)(data_dict)
    return data_dict
class AugmentDataset(Dataset):
    """Wraps a video dataset and applies a fixed augmentation pipeline.

    Every LazyCall augmentation spec is instantiated exactly once at
    construction time, so per-sample access only pays the cost of applying
    the already-built transforms.
    """

    def __init__(self, base_dataset, augmentations_dict):
        """
        Args:
            base_dataset: the video dataset instance to wrap.
            augmentations_dict: the dictionary returned by
                video_train_augmentations() or video_val_augmentations().
        """
        self.base_dataset = base_dataset
        # Pre-build each augmentation once, preserving dict insertion order.
        self.augmentations = [
            (name, instantiate(spec)) for name, spec in augmentations_dict.items()
        ]

    def __len__(self):
        """Number of samples in the wrapped dataset."""
        return len(self.base_dataset)

    def __getitem__(self, index):
        """Fetch a raw sample, tag its aspect category, then augment it."""
        sample = self.base_dataset[index]
        sample = categorize_aspect_and_store(sample)
        for _name, transform in self.augmentations:
            sample = transform(sample)
        return sample
def dataset_entry(
    dataset_name: str,
    dataset_type: str,
    is_train: bool = True,
    resolution="720",
    crop_height=256,
    num_video_frames=25,
) -> "AugmentDataset":
    """Build a video dataset wrapped with train or validation augmentations.

    Args:
        dataset_name: key into _VIDEO_PATTERN_DICT (case-insensitive).
        dataset_type: only "video" is supported.
        is_train: select the training pipeline when True, validation otherwise.
        resolution: target resolution label forwarded to the augmentations.
        crop_height: crop height forwarded to the augmentations.
        num_video_frames: number of frames sampled per clip.

    Returns:
        An AugmentDataset wrapping the configured video dataset.

    Raises:
        ValueError: if dataset_type is not "video" or dataset_name is not
            registered in _VIDEO_PATTERN_DICT.
    """
    if dataset_type != "video":
        raise ValueError(f"Dataset type {dataset_type} is not supported")
    key = dataset_name.lower()
    # Fail fast with an actionable message instead of a bare KeyError.
    if key not in _VIDEO_PATTERN_DICT:
        raise ValueError(
            f"Unknown dataset name {dataset_name!r}; "
            f"expected one of {sorted(_VIDEO_PATTERN_DICT)}"
        )
    # Instantiate the video dataset.
    base_dataset = Dataset(
        video_pattern=_VIDEO_PATTERN_DICT[key],
        num_video_frames=num_video_frames,
    )
    # Pick the training or validation augmentation pipeline.
    augmentation_factory = video_train_augmentations if is_train else video_val_augmentations
    aug_dict = augmentation_factory(
        input_keys=["video"],  # adjust if the dataset uses a different key
        resolution=resolution,
        crop_height=crop_height,
    )
    # Wrap the dataset with the augmentations.
    return AugmentDataset(base_dataset, aug_dict)
if __name__ == "__main__":
    # Quick interactive smoke test; run with:
    # python3 -m cosmos_predict1.tokenizer.training.datasets.dataset_provider
    dataset = dataset_entry(
        # Fixed: "davis_video" has no entry in _VIDEO_PATTERN_DICT and
        # crashed with a KeyError; use the registered dataset name instead.
        dataset_name="hdvila_video",
        dataset_type="video",
        is_train=False,
        resolution="720",
        crop_height=256,
        num_video_frames=25,
    )
    # Print out some basic info.
    print(f"Total samples in dataset: {len(dataset)}")
    # Grab one sample to check shapes, keys, etc.
    if len(dataset) > 0:
        sample_idx = 0
        sample = dataset[sample_idx]
        print(f"Sample index {sample_idx} keys: {list(sample.keys())}")
        if "video" in sample:
            print("Video shape:", sample["video"].shape)
        if "video_name" in sample:
            print("Video metadata:", sample["video_name"])
        print("---\nSample loaded successfully.\n")
    else:
        print("Dataset has no samples!")