File size: 3,966 Bytes
142a1ac
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
import pandas as pd
from typing import List, Tuple, Any, Dict
import time
import json
from pathlib import Path
import decord
from .video_base import VideoDataset


class PandasVideoDataset(VideoDataset):
    """
    ``VideoDataset`` variant that accepts a legacy JSON-lines metadata file.

    The only behavior added on top of the base class is in ``_load_records``:
    a ``.json`` metadata file is converted to a CSV file before the base class
    loader runs.
    """

    def _load_records(self) -> Tuple[List[str], List[str]]:
        """
        Load the records from the metadata file, converting legacy JSON first.

        Each record is a dictionary containing a datapoint's mp4 path / caption etc.
        These entries are required: "video_path", "caption", "height", "width",
        "n_frames", "fps".

        For pandas70m there is one extra key, "youtube_key_segment", which looks
        like "2NQDnwJEBeQ_segment_7". It is the key identifier for the video.

        Pandas 70M comes with a json config file. When ``self.metadata_path``
        has a ``.json`` suffix, the file is read as JSON lines (one object per
        line), normalized, and saved as a csv file with the same stem under
        ``self.data_root``. The conversion is redone on every call.

        Normalization applied to each JSON object:
          - "mp4_path" is renamed to "video_path"
          - "start_frame_index" is renamed to "trim_start"
          - "end_frame_index" is renamed to "trim_end"
          - "prompt_embed_path" becomes "prompt_embeds/<value>.pt"
          - "answers_for_four_questions" is dropped

        Side effects:
            Writes the csv file, reassigns ``self.metadata_path`` to the csv
            path, and prints the conversion time.

        Returns:
            Whatever ``super()._load_records()`` returns for the (possibly
            converted) metadata file.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        if self.metadata_path.suffix == ".json":
            # convert a legacy json file to a csv file we need
            start_time = time.perf_counter()
            legacy_key_renames = (
                ("mp4_path", "video_path"),
                ("start_frame_index", "trim_start"),
                ("end_frame_index", "trim_end"),
            )
            records = []
            # Explicit utf-8 so decoding does not depend on the platform locale.
            with open(
                self.data_root / self.metadata_path, "r", encoding="utf-8"
            ) as f:
                for line in f:
                    line = line.strip()
                    if not line:
                        # Blank / trailing empty lines are not records; skip
                        # them instead of letting json.loads raise.
                        continue
                    item = json.loads(line)
                    for old_key, new_key in legacy_key_renames:
                        if old_key in item:
                            item[new_key] = item.pop(old_key)
                    if "prompt_embed_path" in item:
                        item["prompt_embed_path"] = (
                            "prompt_embeds/" + item["prompt_embed_path"] + ".pt"
                        )
                    item.pop("answers_for_four_questions", None)
                    records.append(item)

            df = pd.DataFrame.from_records(records)
            csv_path = self.metadata_path.with_suffix(".csv")
            df.to_csv(self.data_root / csv_path, index=False)
            # Point the base-class loader at the freshly written csv file.
            self.metadata_path = csv_path
            end_time = time.perf_counter()
            print(f"Time taken for converting records: {end_time - start_time} seconds")

        return super()._load_records()


if __name__ == "__main__":
    # Manual debug run: build a dataset from a hard-coded config, fetch the
    # first sample and dump its fields to stdout.
    import torch
    from omegaconf import OmegaConf

    # Nested config sections, kept separate for readability.
    filtering_section = {
        "disable": False,
        "height": [32, 2160],
        "width": [32, 3840],
        "n_frames": [8, 1000],
        "fps": [1, 60],
    }
    augmentation_section = {"random_flip": 0.5, "ratio": None, "scale": None}

    debug_config = {
        "debug": True,
        "data_root": "/n/holylfs06/LABS/sham_lab/Lab/eiwm_data/pandas/",
        "metadata_path": "pandas_filtered_human_clip_meta_gemini_1.5_flash.json",
        "auto_download": False,
        "force_download": False,
        "test_percentage": 0.1,
        "id_token": "",
        "resolution": [256, 256],
        "n_frames": 8,
        "fps": 30,
        "trim_mode": "speedup",
        "pad_mode": "pad_last",
        "filtering": filtering_section,
        "load_video_latent": False,
        "load_prompt_embed": False,
        "augmentation": augmentation_section,
        "image_to_video": False,
        "check_video_path": False,
    }

    # Wrap the plain dict as an OmegaConf object and build the dataset from it.
    cfg = OmegaConf.create(debug_config)
    dataset = PandasVideoDataset(cfg=cfg, split="training")

    # Fetch the first sample and print every field it carries.
    sample = dataset[0]
    print("\nSample contents:")
    for field_name, field_value in sample.items():
        if isinstance(field_value, torch.Tensor):
            # Tensors are summarized by shape rather than printed in full.
            print(f"{field_name}: Tensor of shape {field_value.shape}")
            continue
        if not isinstance(field_value, dict):
            print(f"{field_name}: {field_value}")
            continue
        # Dict-valued fields are expanded one indented entry per line.
        print(f"{field_name}:")
        for inner_name, inner_value in field_value.items():
            print(f"  {inner_name}: {inner_value}")