# LVP / datasets / pandas.py
# (Hugging Face file-page header preserved as a comment: kiwhansong, "add demo", commit 142a1ac)
import pandas as pd
from typing import List, Tuple, Any, Dict
import time
import json
from pathlib import Path
import decord
from .video_base import VideoDataset
class PandasVideoDataset(VideoDataset):
    """Video dataset for Pandas-70M metadata.

    Transparently migrates a legacy JSON-lines metadata file into the CSV
    layout expected by :class:`VideoDataset`, then defers to the base class
    for the actual record loading.
    """

    def _load_records(self) -> Tuple[List[str], List[str]]:
        """
        Load dataset records via the base class, converting legacy metadata first.

        The base loader expects a CSV whose rows provide at least
        "video_path", "caption", "height", "width", "n_frames" and "fps"
        (Pandas-70M additionally carries a "youtube_key_segment" identifier,
        e.g. "2NQDnwJEBeQ_segment_7"). When ``self.metadata_path`` points at a
        legacy ``.json`` file (JSON-lines format), this method normalizes the
        field names, writes the result as a sibling ``.csv``, repoints
        ``self.metadata_path`` at it, and only then calls the parent loader.

        Returns:
            Whatever ``VideoDataset._load_records`` returns (a tuple of lists).
        """
        if self.metadata_path.suffix == ".json":
            # One-time migration: normalize legacy field names and persist as CSV.
            start_time = time.time()
            # Legacy key -> canonical key expected by the base class.
            rename_map = {
                "mp4_path": "video_path",
                "start_frame_index": "trim_start",
                "end_frame_index": "trim_end",
            }
            rows: List[Dict[str, Any]] = []
            with open(self.data_root / self.metadata_path, "r") as handle:
                for raw_line in handle:
                    record = json.loads(raw_line)
                    for old_key, new_key in rename_map.items():
                        if old_key in record:
                            record[new_key] = record.pop(old_key)
                    if "prompt_embed_path" in record:
                        # Expand the bare embed id into a repo-relative .pt path.
                        record["prompt_embed_path"] = (
                            "prompt_embeds/" + record["prompt_embed_path"] + ".pt"
                        )
                    # QA annotations are not needed by the loader; drop if present.
                    record.pop("answers_for_four_questions", None)
                    rows.append(record)
            csv_path = self.metadata_path.with_suffix(".csv")
            pd.DataFrame.from_records(rows).to_csv(
                self.data_root / csv_path, index=False
            )
            # Point the base class at the converted file for this and future loads.
            self.metadata_path = csv_path
            end_time = time.time()
            print(f"Time taken for converting records: {end_time - start_time} seconds")
        return super()._load_records()
if __name__ == "__main__":
    # Smoke test: build a dataset from a hard-coded debug config on the
    # cluster filesystem and pretty-print the first sample's contents.
    import torch
    from omegaconf import OmegaConf

    debug_config = {
        "debug": True,
        "data_root": "/n/holylfs06/LABS/sham_lab/Lab/eiwm_data/pandas/",
        "metadata_path": "pandas_filtered_human_clip_meta_gemini_1.5_flash.json",
        "auto_download": False,
        "force_download": False,
        "test_percentage": 0.1,
        "id_token": "",
        "resolution": [256, 256],
        "n_frames": 8,
        "fps": 30,
        "trim_mode": "speedup",
        "pad_mode": "pad_last",
        "filtering": {
            "disable": False,
            "height": [32, 2160],
            "width": [32, 3840],
            "n_frames": [8, 1000],
            "fps": [1, 60],
        },
        "load_video_latent": False,
        "load_prompt_embed": False,
        "augmentation": {"random_flip": 0.5, "ratio": None, "scale": None},
        "image_to_video": False,
        "check_video_path": False,
    }

    # Wrap the plain dict so the dataset receives an OmegaConf config.
    cfg = OmegaConf.create(debug_config)
    ds = PandasVideoDataset(cfg=cfg, split="training")

    # Fetch a single sample and dump a human-readable summary of each field.
    sample = ds[0]
    print("\nSample contents:")
    for key, value in sample.items():
        if isinstance(value, dict):
            print(f"{key}:")
            for k, v in value.items():
                print(f"  {k}: {v}")
        elif isinstance(value, torch.Tensor):
            print(f"{key}: Tensor of shape {value.shape}")
        else:
            print(f"{key}: {value}")