# multishot/check_dataset.py
# Uploaded to the Hugging Face Hub by PencilHu ("Upload folder using huggingface_hub"),
# commit 85752bc, 4.55 kB.
import argparse
import json
import random
import time
from pathlib import Path
import imageio.v2 as imageio
import numpy as np
import yaml
from multi_view.datasets.videodataset import MulltiShot_MultiView_Dataset
def _frame_to_uint8(frame) -> np.ndarray:
    """Return *frame* as a uint8 array; objects with ``convert`` (e.g. PIL images) go through RGB first."""
    if hasattr(frame, "convert"):
        frame = frame.convert("RGB")
    array = np.asarray(frame)
    if array.dtype != np.uint8:
        # NOTE(review): assumes non-uint8 frames are already on a 0-255 scale;
        # float frames in [0, 1] would come out nearly black.
        array = np.clip(array, 0, 255).astype(np.uint8)
    return array


def save_video(frames, path: Path, fps: int = 16) -> None:
    """Encode a sequence of frames into the video file at *path*.

    Args:
        frames: Any iterable of frames (list, tuple, generator, or an array
            iterated along its first axis). Each frame is either an object
            with a ``convert`` method (converted to RGB) or anything
            ``np.asarray`` accepts; non-uint8 data is clipped to [0, 255]
            and cast to uint8.
        path: Destination video file.
        fps: Frame rate written to the output.

    Raises:
        ValueError: If *frames* yields no frames.
    """
    # Peek at the first frame instead of `if not frames`. That test is
    # ambiguous for NumPy arrays and always true for generators, and would
    # otherwise let us create an empty video file.
    frame_iter = iter(frames)
    try:
        first = next(frame_iter)
    except StopIteration:
        raise ValueError("No frames to save.") from None

    writer = imageio.get_writer(str(path), fps=fps)
    try:
        writer.append_data(_frame_to_uint8(first))
        for frame in frame_iter:
            writer.append_data(_frame_to_uint8(frame))
    finally:
        # Always release the encoder, even if a frame fails to convert.
        writer.close()
def ensure_dir(path: Path) -> None:
    """Make sure the directory *path* exists, creating missing parents.

    Calling this on an existing directory is a no-op; a non-directory
    already occupying *path* still raises ``FileExistsError``.
    """
    path.mkdir(exist_ok=True, parents=True)
def reseed(base_seed: int, idx: int) -> None:
    """Seed Python's ``random`` and NumPy's global RNG for sample *idx*.

    The effective seed is ``base_seed + idx``, so each sample index gets
    its own reproducible random state.

    NOTE(review): torch's RNG is not seeded here; if the dataset samples
    with torch, dumps may not be reproducible — confirm.
    """
    seed = base_seed + idx
    random.seed(seed)
    # np.random.seed only accepts 0 <= seed < 2**32. Fold out-of-range values
    # (e.g. a negative --seed) into that range instead of crashing; in-range
    # seeds are passed through unchanged.
    np.random.seed(seed % (2**32))
def _parse_args() -> argparse.Namespace:
    """Build the command-line parser and parse ``sys.argv``."""
    parser = argparse.ArgumentParser(description="Inspect dataset samples and dump training inputs.")
    parser.add_argument("--train_yaml", type=str, required=True,
                        help="Training YAML; its dataset_args section supplies base_path/height/width/ref_num.")
    parser.add_argument("--dataset_json", type=str, default="",
                        help="Dataset path; overrides dataset_args.base_path from the YAML.")
    parser.add_argument("--output_dir", type=str, default="",
                        help="Output directory (default: logs/dataset_check/<timestamp> next to this script).")
    parser.add_argument("--indices", type=int, nargs="+", default=[],
                        help="Explicit sample indices to dump (takes precedence over --num_samples).")
    parser.add_argument("--num_samples", type=int, default=4,
                        help="Dump the first N samples when --indices is not given.")
    parser.add_argument("--seed", type=int, default=1234,
                        help="Base seed; sample i is loaded after seeding with seed + i.")
    parser.add_argument("--split", choices=["train", "test", "all"], default="train",
                        help="Split to inspect; 'all' aliases data_train and data_test to the full data list.")
    parser.add_argument("--fps", type=int, default=16,
                        help="Frame rate of the dumped input.mp4 files.")
    return parser.parse_args()


def _load_dataset_args(train_yaml: str) -> dict:
    """Return the ``dataset_args`` mapping from *train_yaml*, or ``{}`` if it is missing or empty."""
    with open(train_yaml, "r", encoding="utf-8") as f:
        # safe_load returns None for an empty document.
        conf = yaml.safe_load(f) or {}
    # A bare `dataset_args:` key parses to None; treat it like a missing section.
    return conf.get("dataset_args") or {}


def _dump_sample(dataset, idx: int, output_dir: Path, seed: int, fps: int) -> dict:
    """Load sample *idx*, write its video/refs/summary under *output_dir*, and return the summary dict."""
    reseed(seed, idx)
    sample = dataset[idx]

    # Normalise missing/None entries to empty lists. Use `len()` rather than
    # `if video:` so array-like videos (ambiguous truth value) also work.
    video = sample.get("video")
    if video is None:
        video = []
    ref_images = sample.get("ref_images")
    if ref_images is None:
        ref_images = []

    sample_dir = output_dir / f"sample_{idx}"
    ensure_dir(sample_dir)

    video_path = sample_dir / "input.mp4"
    num_frames = len(video)
    video_saved = num_frames > 0
    if video_saved:
        save_video(video, video_path, fps=fps)

    refs_dir = sample_dir / "refs"
    ensure_dir(refs_dir)
    # ref_images is consumed as a list of groups (one per identity) of images with `.save`.
    for id_i, ref_group in enumerate(ref_images):
        for img_i, img in enumerate(ref_group):
            img.save(refs_dir / f"id{id_i}_img{img_i}.png")

    summary = {
        "index": idx,
        "video_path": sample.get("video_path"),
        "num_frames": num_frames,
        "shot_num": sample.get("shot_num"),
        "pre_shot_caption": sample.get("pre_shot_caption", []),
        "ref_num": sample.get("ref_num"),
        "ID_num": sample.get("ID_num"),
        # Only reference input.mp4 when it was actually written.
        "saved_video": str(video_path) if video_saved else None,
        "saved_refs_dir": str(refs_dir),
    }
    # NOTE(review): json.dump raises TypeError if any sample field is not
    # JSON-serialisable (e.g. NumPy scalars) — confirm the dataset returns plain types.
    with (sample_dir / "summary.json").open("w", encoding="utf-8") as f:
        json.dump(summary, f, ensure_ascii=False, indent=2)
    return summary


def main() -> int:
    """Dump selected dataset samples for manual inspection.

    Reads ``dataset_args`` from ``--train_yaml`` and builds
    ``MulltiShot_MultiView_Dataset``. For each selected index it writes
    ``sample_<idx>/input.mp4`` (when the sample has frames),
    ``sample_<idx>/refs/id<i>_img<j>.png`` and ``sample_<idx>/summary.json``.
    It then writes a top-level ``manifest.json``.

    Returns:
        Process exit code (0 on success).

    Raises:
        ValueError: If no dataset path is given on the CLI or in the YAML.
    """
    args = _parse_args()
    dataset_args = _load_dataset_args(args.train_yaml)

    dataset_json = args.dataset_json or dataset_args.get("base_path", "")
    if not dataset_json:
        raise ValueError("dataset_json is required (or set dataset_args.base_path in YAML).")
    height = int(dataset_args.get("height", 480))
    width = int(dataset_args.get("width", 832))
    ref_num = int(dataset_args.get("ref_num", 3))

    dataset = MulltiShot_MultiView_Dataset(
        dataset_base_path=dataset_json,
        resolution=(height, width),
        ref_num=ref_num,
        training=args.split != "test",
    )
    if args.split == "all":
        # Presumably the dataset indexes data_train/data_test depending on
        # `training`; aliasing both to `data` exposes every sample — confirm
        # against the dataset class.
        dataset.data_train = dataset.data
        dataset.data_test = dataset.data
        dataset.training = True

    if args.output_dir:
        output_dir = Path(args.output_dir)
    else:
        ts = time.strftime("%Y%m%d_%H%M%S")
        output_dir = Path(__file__).resolve().parent / "logs" / "dataset_check" / ts
    ensure_dir(output_dir)

    if args.indices:
        indices = args.indices
    else:
        indices = list(range(min(args.num_samples, len(dataset))))

    manifest = {
        "train_yaml": args.train_yaml,
        "dataset_json": dataset_json,
        "split": args.split,
        "height": height,
        "width": width,
        "ref_num": ref_num,
        "fps": args.fps,
        "indices": indices,
        "samples": [],
    }
    for idx in indices:
        manifest["samples"].append(_dump_sample(dataset, idx, output_dir, args.seed, args.fps))

    with (output_dir / "manifest.json").open("w", encoding="utf-8") as f:
        json.dump(manifest, f, ensure_ascii=False, indent=2)
    print(f"Saved dataset check logs to: {output_dir}")
    return 0
if __name__ == "__main__":
    # Use main()'s return value as the process exit status.
    exit_status = main()
    raise SystemExit(exit_status)