PencilFolder / examples /wanvideo /model_training /prepare_instancev_instancecap_bbox.py
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
Raw
History Blame Contribute Delete
13.2 kB
#!/usr/bin/env python3
"""
Prepare InstanceV training data from InstanceCap + InstanceCap-BBox.

Writes one JSON object per line (JSONL) to the output file. Each object has
the following shape:

    {
        "video": "OpenVid1M-Video-InstanceCap/<video>.mp4",
        "prompt": "global + background + camera",
        "instance_prompts": ["instance prompt 1", ...],
        "instance_mask_dirs": [
            {"mask_dir": "/abs/path/to/masks", "instance_id": 1, "num_frames": 20},
            ...
        ]
    }
"""
import argparse
import glob
import json
import math
import multiprocessing as mp
import os
from pathlib import Path
import imageio.v2 as imageio
from PIL import Image, ImageDraw
from tqdm import tqdm
# Parsed CLI arguments shared with pool workers: assigned by _init_worker when a
# worker process starts and read by _process_line_worker for every line.
_WORKER_ARGS = None
def parse_args():
    """Define the command-line options and parse them from ``sys.argv``.

    Returns:
        argparse.Namespace with the input/output locations, the instance and
        frame-count filters, and the processing switches.
    """
    parser = argparse.ArgumentParser(description="Prepare InstanceV data from InstanceCap-BBox")

    # Path-like options share the same shape, so register them from a table.
    # Order matters only for ``--help`` output and is kept as-is.
    path_options = (
        (
            "--instancecap_path",
            "/data/rczhang/PencilFolder/data/InstanceCap/InstanceCap.jsonl",
            "Path to InstanceCap.jsonl",
        ),
        (
            "--instancecap_bbox_dir",
            "/data/rczhang/PencilFolder/data/InstanceCap-BBox",
            "Directory containing InstanceCap-BBox folders",
        ),
        (
            "--video_dir",
            "/data/rczhang/PencilFolder/data/OpenVid1M-Video-InstanceCap",
            "Directory containing videos",
        ),
        (
            "--mask_root_dir",
            "/data/rczhang/PencilFolder/data/InstanceCap-BBox-Masks",
            "Root directory to store generated bbox masks",
        ),
        (
            "--output_path",
            "/data/rczhang/PencilFolder/data/InstanceCap/instancev_instancecap_bbox.jsonl",
            "Output JSONL path",
        ),
        (
            "--dataset_base_path",
            "/data/rczhang/PencilFolder/data",
            "Base path used by UnifiedDataset",
        ),
    )
    for flag, default_value, help_text in path_options:
        parser.add_argument(flag, type=str, default=default_value, help=help_text)

    # Filters on how many instances / frames a sample must have.
    parser.add_argument("--min_instances", default=1, type=int)
    parser.add_argument("--max_instances", default=5, type=int)
    parser.add_argument("--min_frames", default=0, type=int, help="Minimum frame count required.")

    # Processing switches.
    parser.add_argument("--overwrite_masks", action="store_true")
    parser.add_argument("--check_video", action="store_true")
    parser.add_argument("--limit", default=None, type=int)
    parser.add_argument(
        "--num_workers",
        default=1,
        type=int,
        help="Number of worker processes to use for preprocessing.",
    )

    parsed = parser.parse_args()
    return parsed
def _safe_relpath(path: str, base_path: str) -> str:
    """Express ``path`` relative to ``base_path``.

    When ``base_path`` is empty or None, ``path`` is returned unchanged.
    """
    return os.path.relpath(path, base_path) if base_path else path
def _clamp_bbox(bbox, width: int, height: int):
    """Snap a float bbox to integer pixels and clip it to the image bounds.

    The top-left corner is rounded down and the bottom-right corner rounded
    up, then the box is clipped to ``[0, width] x [0, height]``.

    Returns:
        ``(left, top, right, bottom)``, or None when ``bbox`` is missing, does
        not have exactly four values, or has no area after clipping.
    """
    if not bbox or len(bbox) != 4:
        return None

    x_min, y_min, x_max, y_max = bbox
    clamped = (
        max(0, int(math.floor(x_min))),
        max(0, int(math.floor(y_min))),
        min(width, int(math.ceil(x_max))),
        min(height, int(math.ceil(y_max))),
    )
    left, top, right, bottom = clamped
    has_area = left < right and top < bottom
    return clamped if has_area else None
def _is_video_readable(video_path: str) -> bool:
    """Report whether the first frame of ``video_path`` can be decoded.

    Any exception raised while opening the reader, fetching frame 0, or
    closing the reader is treated as "not readable".
    """
    readable = True
    try:
        reader = imageio.get_reader(video_path)
        try:
            # Decoding one frame is enough to show the container is usable.
            reader.get_data(0)
        finally:
            reader.close()
    except Exception:
        readable = False
    return readable
def build_instance_prompt(instance_info: dict) -> str:
    """Compose a one-line text prompt for a single instance.

    The prompt is built from the "Class", "Appearance" and
    "Actions and Motion" fields of ``instance_info``, in that order, joined
    as sentences. A missing "Class" defaults to "object"; empty, None or
    whitespace-only fields are skipped.

    Args:
        instance_info: Mapping describing one instance.

    Returns:
        The prompt, terminated by a period, or "" when every field is empty.
    """

    def _clean(value) -> str:
        # Collapse runs of whitespace and trim the ends so that joining the
        # parts cannot produce ". ." or double-space artefacts.
        return " ".join(str(value).split()) if value else ""

    cls = _clean(instance_info.get("Class", "object"))
    appearance = _clean(instance_info.get("Appearance", ""))
    actions = _clean(instance_info.get("Actions and Motion", ""))

    parts = []
    if cls:
        parts.append(f"A {cls}")
    if appearance:
        parts.append(appearance)
    if actions:
        parts.append(actions)

    # A field that already ends with "." yields ".." once joined; fold it back.
    prompt = ". ".join(parts).replace("..", ".")
    # Normalise whitespace across the whole prompt (the previous
    # replace(" ", " ") call was a no-op and never collapsed anything).
    prompt = " ".join(prompt.split())
    if prompt and not prompt.endswith("."):
        prompt += "."
    return prompt
def _read_meta_frame_count(meta_path: str):
    """Return the non-zero ``frame_count`` stored in ``meta.json``, else None.

    A missing file, unparsable content, or an absent/zero count all yield
    None so the caller can treat the frame count as unknown.
    """
    if not os.path.isfile(meta_path):
        return None
    try:
        with open(meta_path, "r", encoding="utf-8") as f:
            meta = json.load(f)
        return int(meta.get("frame_count", 0)) or None
    except Exception:
        return None


def _write_bbox_masks(frames, instance_ids, mask_output_dir: str) -> None:
    """Render one filled-rectangle mask PNG per (frame, instance) pair.

    Args:
        frames: Parsed per-frame mask JSON objects, in frame order.
        instance_ids: Instance ids to render, in output order.
        mask_output_dir: Directory that receives
            ``<frame_idx:06d>_No.<instance_id>.png`` files.

    An instance with no usable bbox in a frame gets an all-black mask.
    Exceptions from malformed frame data or image I/O propagate to the caller.
    """
    os.makedirs(mask_output_dir, exist_ok=True)
    keep_ids = set(instance_ids)
    for frame_idx, frame_data in enumerate(frames):
        mask_h = int(frame_data.get("mask_height", 0))
        mask_w = int(frame_data.get("mask_width", 0))

        bbox_map = {}
        for label in frame_data.get("labels", {}).values():
            raw_id = label.get("instance_id")
            if raw_id is None:
                continue
            inst_id = int(raw_id)
            if inst_id in keep_ids:
                bbox_map[inst_id] = (
                    label.get("x1", 0),
                    label.get("y1", 0),
                    label.get("x2", 0),
                    label.get("y2", 0),
                )

        for inst_id in instance_ids:
            mask = Image.new("L", (mask_w, mask_h), 0)
            bbox = bbox_map.get(inst_id)
            if bbox is not None:
                coords = _clamp_bbox(bbox, mask_w, mask_h)
                if coords is not None:
                    ImageDraw.Draw(mask).rectangle(coords, fill=255)
            mask.save(os.path.join(mask_output_dir, f"{frame_idx:06d}_No.{inst_id}.png"))


def _process_line(line: str, args):
    """Convert one InstanceCap JSONL line into an InstanceV training entry.

    Args:
        line: One raw line of the InstanceCap JSONL file.
        args: Parsed options; reads ``video_dir``, ``instancecap_bbox_dir``,
            ``mask_root_dir``, ``dataset_base_path``, ``min_instances``,
            ``max_instances``, ``min_frames``, ``overwrite_masks`` and
            ``check_video``.

    Returns:
        A ``(status, payload)`` tuple where ``status`` is one of:

        * ``"ok"`` - ``payload`` is the output entry dict;
        * ``"error"`` - ``payload`` is a human-readable message;
        * ``"skip_empty"``, ``"skip_missing_video"``, ``"skip_unreadable"``,
          ``"skip_missing_bbox"``, ``"skip_short"``, ``"skip_instances"`` -
          ``payload`` is None.

    Malformed input is reported through the ``"error"`` status rather than by
    raising, so a single bad record does not abort the whole run.
    """
    if not line.strip():
        return "skip_empty", None
    try:
        sample = json.loads(line)
    except Exception as exc:  # pragma: no cover - defensive
        return "error", f"Invalid JSON: {exc}"
    if not isinstance(sample, dict):
        return "error", f"Invalid JSON: expected an object, got {type(sample).__name__}"

    video_name = sample.get("Video", "")
    if not video_name:
        return "skip_missing_video", None
    video_path = os.path.join(args.video_dir, video_name)
    if not os.path.isfile(video_path):
        return "skip_missing_video", None
    if args.check_video and not _is_video_readable(video_path):
        return "skip_unreadable", None

    video_stem = Path(video_name).stem
    bbox_dir = os.path.join(args.instancecap_bbox_dir, video_stem)
    instances_path = os.path.join(bbox_dir, "instances.json")
    meta_path = os.path.join(bbox_dir, "meta.json")
    mask_json_dir = os.path.join(bbox_dir, "json_data")
    if not (os.path.isdir(bbox_dir) and os.path.isfile(instances_path) and os.path.isdir(mask_json_dir)):
        return "skip_missing_bbox", None

    try:
        with open(instances_path, "r", encoding="utf-8") as f:
            inst_data = json.load(f)
        instances = inst_data.get("instances", [])
    except Exception as exc:  # pragma: no cover - defensive
        return "error", f"Invalid instances.json for {video_name}: {exc}"
    if not instances:
        return "skip_instances", None

    # Escape the directory so glob metacharacters in a video name ("[", "*",
    # "?") are matched literally instead of being treated as a pattern.
    mask_files = sorted(glob.glob(os.path.join(glob.escape(mask_json_dir), "mask_*.json")))
    if not mask_files:
        return "skip_missing_bbox", None

    # meta.json is only consulted for the minimum-length filter.
    if args.min_frames:
        if len(mask_files) < args.min_frames:
            return "skip_short", None
        meta_frames = _read_meta_frame_count(meta_path)
        if meta_frames is not None and meta_frames < args.min_frames:
            return "skip_short", None

    # Parse every frame once: collect the instance ids that actually appear,
    # and keep the parsed frames for mask rendering below.
    frames = []
    present_ids = set()
    for mf in mask_files:
        try:
            with open(mf, "r", encoding="utf-8") as f:
                frame_data = json.load(f)
            for label in frame_data.get("labels", {}).values():
                inst_id = label.get("instance_id")
                if inst_id is not None:
                    present_ids.add(int(inst_id))
        except Exception as exc:  # pragma: no cover - defensive
            return "error", f"Invalid mask file {mf}: {exc}"
        frames.append(frame_data)

    struct_desc = sample.get("Structural Description")
    if not isinstance(struct_desc, dict):
        struct_desc = {}
    main_instances = struct_desc.get("Main Instance")
    if not isinstance(main_instances, dict):
        main_instances = {}

    # Keep instances that occur in at least one frame, pairing each with a prompt.
    filtered = []
    for inst in instances:
        try:
            inst_id = int(inst.get("instance_id", -1))
        except (AttributeError, TypeError, ValueError) as exc:
            return "error", f"Invalid instances.json for {video_name}: {exc}"
        if inst_id < 0 or inst_id not in present_ids:
            continue
        source_key = inst.get("source_instance")
        prompt = ""
        if source_key and source_key in main_instances:
            source_info = main_instances[source_key]
            if isinstance(source_info, dict):
                prompt = build_instance_prompt(source_info)
        if not prompt:
            # Fall back to the tracker-provided name/caption.
            name = inst.get("name") or "object"
            caption = (inst.get("caption") or "").strip()
            prompt = f"{name}. {caption}".strip() if caption else name
        filtered.append((inst_id, prompt))

    if args.max_instances is not None:
        filtered = filtered[: args.max_instances]
    if len(filtered) < args.min_instances:
        return "skip_instances", None

    mask_output_dir = os.path.join(args.mask_root_dir, f"{video_stem}_masks")
    # NOTE: an existing directory is trusted as complete unless
    # --overwrite_masks is given; a directory left behind by an interrupted
    # run is therefore not regenerated automatically.
    if args.overwrite_masks or not os.path.isdir(mask_output_dir):
        try:
            _write_bbox_masks(frames, [inst_id for inst_id, _ in filtered], mask_output_dir)
        except Exception as exc:
            return "error", f"Failed to write masks for {video_name}: {exc}"

    num_frames = len(mask_files)
    global_desc = sample.get("Global Description", "")
    background = struct_desc.get("Background Detail", "")
    camera = struct_desc.get("Camera Movement", "")
    full_prompt = " ".join([p for p in [global_desc, background, camera] if p])

    entry = {
        "video": _safe_relpath(video_path, args.dataset_base_path),
        "prompt": full_prompt,
        "instance_prompts": [prompt for _id, prompt in filtered],
        "instance_mask_dirs": [
            {"mask_dir": mask_output_dir, "instance_id": inst_id, "num_frames": num_frames}
            for inst_id, _prompt in filtered
        ],
    }
    return "ok", entry
def _init_worker(args):
    """Pool initializer: stash the parsed options in this process's module state.

    The stored value is later read through the module-level ``_WORKER_ARGS``.
    """
    globals()["_WORKER_ARGS"] = args
def _process_line_worker(line: str):
    """Process one line inside a pool worker using the options in ``_WORKER_ARGS``.

    Returns whatever ``_process_line`` returns for ``line``.
    """
    worker_args = _WORKER_ARGS
    outcome = _process_line(line, worker_args)
    return outcome
def main():
    """Run the InstanceCap -> InstanceV conversion and print a summary.

    Reads ``args.instancecap_path`` line by line, processes each line either
    in-process (``--num_workers <= 1``) or in a worker pool, writes every
    successful entry as one JSON line to ``args.output_path``, and stops once
    ``--limit`` entries have been written. Per-line error messages are printed
    as they occur; counts per outcome are printed at the end.
    """
    args = parse_args()

    # dirname is "" for a bare filename such as "out.jsonl"; makedirs("")
    # would raise, and the current directory needs no creation anyway.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    os.makedirs(args.mask_root_dir, exist_ok=True)

    processed = 0
    wrote = 0
    errors = 0
    skipped = {
        "skip_missing_video": 0,
        "skip_unreadable": 0,
        "skip_missing_bbox": 0,
        "skip_short": 0,
        "skip_instances": 0,
    }

    def limit_reached() -> bool:
        return args.limit is not None and wrote >= args.limit

    with open(args.instancecap_path, "r", encoding="utf-8") as f_in, open(
        args.output_path, "w", encoding="utf-8"
    ) as f_out:

        def consume(results) -> bool:
            """Tally and write results; return True when stopped by --limit."""
            nonlocal processed, wrote, errors
            # Covers --limit 0 identically in both execution modes.
            if limit_reached():
                return True
            for status, payload in tqdm(results, desc="Processing InstanceCap"):
                if status == "skip_empty":
                    continue
                processed += 1
                if status == "ok":
                    f_out.write(json.dumps(payload, ensure_ascii=False) + "\n")
                    wrote += 1
                    if limit_reached():
                        return True
                elif status in skipped:
                    skipped[status] += 1
                else:
                    errors += 1
                    if payload:
                        # Surface the reason instead of only counting it;
                        # tqdm.write keeps the progress bar intact.
                        tqdm.write(f"[error] {payload}")
            return False

        if args.num_workers <= 1:
            # Generator keeps processing lazy so --limit stops the work early.
            consume(_process_line(line, args) for line in f_in)
        else:
            pool = mp.Pool(processes=args.num_workers, initializer=_init_worker, initargs=(args,))
            drained = False
            try:
                results = pool.imap_unordered(_process_line_worker, f_in, chunksize=1)
                drained = not consume(results)
            finally:
                # Only a fully drained pool is closed gracefully. After an
                # early stop or an exception, terminate so join() does not
                # wait for the remaining lines to be processed.
                if drained:
                    pool.close()
                else:
                    pool.terminate()
                pool.join()

    print("Done.")
    print(f"Processed: {processed}")
    print(f"Wrote: {wrote}")
    print(f"Skipped (missing video): {skipped['skip_missing_video']}")
    print(f"Skipped (unreadable video): {skipped['skip_unreadable']}")
    print(f"Skipped (missing bbox): {skipped['skip_missing_bbox']}")
    print(f"Skipped (short clips): {skipped['skip_short']}")
    print(f"Skipped (insufficient instances): {skipped['skip_instances']}")
    print(f"Errors: {errors}")
# Script entry point: run main() only when this file is executed directly.
if __name__ == "__main__":
    main()