#!/usr/bin/env python3 """ InstanceV 训练数据预处理脚本 功能: 1. 读取 InstanceCap.jsonl(包含 Video, Global Description, Structural Description) 2. 匹配 InstanceLabel 中的 mask 目录 3. 匹配 OpenVid1M-Video 中的视频文件 4. 生成适合训练的 JSONL 格式 数据目录结构: data/InstanceCap/InstanceCap.jsonl # 原始标注 data/InstanceLabel/{video_name}_masks/ # 每个 instance 的 mask 序列 data/OpenVid1M-Video/{video_name}.mp4 # 原始视频 输出格式(每行一个 JSON): { "video": "path/to/video.mp4", "prompt": "global description", "instance_prompts": ["instance 0 description", "instance 1 description", ...], "instance_mask_dirs": ["path/to/masks/No.0", "path/to/masks/No.1", ...] } """ import os import json import argparse from pathlib import Path from collections import defaultdict from tqdm import tqdm def parse_args(): parser = argparse.ArgumentParser(description="Prepare InstanceV training data") parser.add_argument( "--instancecap_path", type=str, default="/data/rczhang/PencilFolder/data/InstanceCap/InstanceCap.jsonl", help="Path to InstanceCap.jsonl", ) parser.add_argument( "--instance_label_dir", type=str, default="/data/rczhang/PencilFolder/data/InstanceLabel", help="Directory containing instance masks", ) parser.add_argument( "--video_dir", type=str, default="/data/rczhang/PencilFolder/data/OpenVid1M-Video", help="Directory containing source videos", ) parser.add_argument( "--output_path", type=str, default="/data/rczhang/PencilFolder/data/instancev_train.jsonl", help="Output JSONL path", ) parser.add_argument( "--min_instances", type=int, default=1, help="Minimum number of instances required", ) parser.add_argument( "--max_instances", type=int, default=10, help="Maximum number of instances to keep", ) parser.add_argument( "--use_dense_caption", action="store_true", help="Use dense caption format for instance prompts", ) return parser.parse_args() def get_video_name_from_path(video_path: str) -> str: """从 video 路径提取 video name(不含扩展名)""" return Path(video_path).stem def find_mask_dirs(instance_label_dir: str, video_name: str) -> dict: """ 查找某个视频对应的所有 instance mask 目录 mask 目录结构: {video_name}_masks/ mask 文件命名: {frame_id:06d}_No.{instance_id}.png Returns: dict: {instance_id: [mask_file_paths]} """ mask_dir = os.path.join(instance_label_dir, f"{video_name}_masks") if not os.path.isdir(mask_dir): return {} instance_masks = defaultdict(list) for fname in sorted(os.listdir(mask_dir)): if not fname.endswith(".png"): continue # 解析文件名: 000000_No.0.png parts = fname.replace(".png", "").split("_No.") if len(parts) != 2: continue frame_id_str, inst_id_str = parts try: frame_id = int(frame_id_str) inst_id = int(inst_id_str) except ValueError: continue instance_masks[inst_id].append(os.path.join(mask_dir, fname)) return dict(instance_masks) def build_instance_prompt(instance_info: dict) -> str: """ 从 InstanceCap 的 instance 信息构建 prompt instance_info 结构: { "Class": "person", "Appearance": "...", "Actions and Motion": "...", "Position": "..." } """ cls = instance_info.get("Class", "object") appearance = instance_info.get("Appearance", "") actions = instance_info.get("Actions and Motion", "") # 构建精简但信息丰富的 prompt parts = [] if cls: parts.append(f"A {cls}") if appearance: parts.append(appearance) if actions: parts.append(actions) prompt = ". ".join(parts) # 清理多余空格和标点 prompt = prompt.replace("..", ".").replace(" ", " ").strip() if prompt and not prompt.endswith("."): prompt += "." return prompt def process_sample( sample: dict, instance_label_dir: str, video_dir: str, min_instances: int, max_instances: int, ) -> dict | None: """ 处理单个样本,返回训练格式的 dict 或 None """ video_name = get_video_name_from_path(sample.get("Video", "")) if not video_name: return None # 检查视频是否存在 video_path = os.path.join(video_dir, f"{video_name}.mp4") if not os.path.isfile(video_path): return None # 查找 mask 目录 instance_masks = find_mask_dirs(instance_label_dir, video_name) if len(instance_masks) < min_instances: return None # 提取 instance 信息 struct_desc = sample.get("Structural Description", {}) main_instances = struct_desc.get("Main Instance", {}) instance_prompts = [] instance_mask_dirs = [] for inst_key in sorted(main_instances.keys()): # inst_key: "No.0", "No.1", ... try: inst_id = int(inst_key.replace("No.", "")) except ValueError: continue if inst_id not in instance_masks: continue inst_info = main_instances[inst_key] prompt = build_instance_prompt(inst_info) if not prompt: continue instance_prompts.append(prompt) # 存储整个 mask 目录路径(而非单个文件列表,训练时动态加载) mask_dir = os.path.dirname(instance_masks[inst_id][0]) instance_mask_dirs.append({ "mask_dir": mask_dir, "instance_id": inst_id, "num_frames": len(instance_masks[inst_id]), }) if len(instance_prompts) >= max_instances: break if len(instance_prompts) < min_instances: return None # 构建输出 global_desc = sample.get("Global Description", "") background = struct_desc.get("Background Detail", "") camera = struct_desc.get("Camera Movement", "") # 合并为完整 prompt full_prompt_parts = [global_desc] if background: full_prompt_parts.append(background) if camera: full_prompt_parts.append(camera) full_prompt = " ".join(full_prompt_parts) return { "video": video_path, "prompt": full_prompt, "instance_prompts": instance_prompts, "instance_mask_dirs": instance_mask_dirs, } def main(): args = parse_args() # 读取 InstanceCap print(f"Loading InstanceCap from {args.instancecap_path}") samples = [] with open(args.instancecap_path, "r", encoding="utf-8") as f: for line in f: line = line.strip() if not line: continue try: samples.append(json.loads(line)) except json.JSONDecodeError as e: print(f"Warning: Failed to parse line: {e}") continue print(f"Loaded {len(samples)} samples") # 处理每个样本 output_samples = [] for sample in tqdm(samples, desc="Processing samples"): result = process_sample( sample, args.instance_label_dir, args.video_dir, args.min_instances, args.max_instances, ) if result is not None: output_samples.append(result) print(f"Valid samples: {len(output_samples)} / {len(samples)}") # 写入输出 os.makedirs(os.path.dirname(args.output_path), exist_ok=True) with open(args.output_path, "w", encoding="utf-8") as f: for sample in output_samples: f.write(json.dumps(sample, ensure_ascii=False) + "\n") print(f"Saved to {args.output_path}") # 打印统计信息 if output_samples: avg_instances = sum(len(s["instance_prompts"]) for s in output_samples) / len(output_samples) print(f"Average instances per sample: {avg_instances:.2f}") if __name__ == "__main__": main()