PencilFolder / examples /wanvideo /model_training /prepare_instancev_data.py
PencilHu's picture
Upload folder using huggingface_hub
1146a67 verified
Raw
History Blame Contribute Delete
8.35 kB
#!/usr/bin/env python3
"""
InstanceV 训练数据预处理脚本
功能:
1. 读取 InstanceCap.jsonl(包含 Video, Global Description, Structural Description)
2. 匹配 InstanceLabel 中的 mask 目录
3. 匹配 OpenVid1M-Video 中的视频文件
4. 生成适合训练的 JSONL 格式
数据目录结构:
data/InstanceCap/InstanceCap.jsonl # 原始标注
data/InstanceLabel/{video_name}_masks/ # 每个 instance 的 mask 序列
data/OpenVid1M-Video/{video_name}.mp4 # 原始视频
输出格式(每行一个 JSON):
{
"video": "path/to/video.mp4",
"prompt": "global description",
"instance_prompts": ["instance 0 description", "instance 1 description", ...],
"instance_mask_dirs": [{"mask_dir": "path/to/{video_name}_masks", "instance_id": 0, "num_frames": N}, ...]
}
"""
import os
import json
import argparse
from pathlib import Path
from collections import defaultdict
from tqdm import tqdm
def parse_args(argv=None):
    """Parse command-line options for the InstanceV preprocessing script.

    Args:
        argv: Optional list of argument strings. ``None`` (the default) makes
            argparse read ``sys.argv[1:]``, so existing callers are unaffected;
            passing a list allows programmatic use and testing.

    Returns:
        argparse.Namespace with ``instancecap_path``, ``instance_label_dir``,
        ``video_dir``, ``output_path``, ``min_instances``, ``max_instances``
        and ``use_dense_caption``.

    Raises:
        SystemExit: via ``parser.error`` when the instance-count limits are
            inconsistent (negative minimum, maximum below 1, or min > max).
    """
    parser = argparse.ArgumentParser(description="Prepare InstanceV training data")
    parser.add_argument(
        "--instancecap_path",
        type=str,
        default="/data/rczhang/PencilFolder/data/InstanceCap/InstanceCap.jsonl",
        help="Path to InstanceCap.jsonl",
    )
    parser.add_argument(
        "--instance_label_dir",
        type=str,
        default="/data/rczhang/PencilFolder/data/InstanceLabel",
        help="Directory containing instance masks",
    )
    parser.add_argument(
        "--video_dir",
        type=str,
        default="/data/rczhang/PencilFolder/data/OpenVid1M-Video",
        help="Directory containing source videos",
    )
    parser.add_argument(
        "--output_path",
        type=str,
        default="/data/rczhang/PencilFolder/data/instancev_train.jsonl",
        help="Output JSONL path",
    )
    parser.add_argument(
        "--min_instances",
        type=int,
        default=1,
        help="Minimum number of instances required",
    )
    parser.add_argument(
        "--max_instances",
        type=int,
        default=10,
        help="Maximum number of instances to keep",
    )
    # NOTE(review): this flag is not read anywhere in the visible script;
    # confirm whether a dense-caption prompt format still needs implementing.
    parser.add_argument(
        "--use_dense_caption",
        action="store_true",
        help="Use dense caption format for instance prompts",
    )
    args = parser.parse_args(argv)

    # Reject limits that would silently drop every sample: process_sample keeps
    # at most max_instances instances and then requires at least min_instances.
    if args.min_instances < 0:
        parser.error("--min_instances must be >= 0")
    if args.max_instances < 1:
        parser.error("--max_instances must be >= 1")
    if args.min_instances > args.max_instances:
        parser.error("--min_instances must not exceed --max_instances")
    return args
def get_video_name_from_path(video_path: str) -> str:
    """Return the file name of *video_path* with its last extension removed.

    Examples: ``"clips/abc.mp4"`` -> ``"abc"``; only the final suffix is
    dropped, so ``"x.tar.gz"`` -> ``"x.tar"``; ``""`` -> ``""``.
    """
    as_path = Path(video_path)
    name_without_ext = as_path.stem
    return name_without_ext
def find_mask_dirs(instance_label_dir: str, video_name: str) -> dict:
    """Collect the per-instance mask files for one video.

    Masks are looked up in ``{instance_label_dir}/{video_name}_masks/`` and
    are expected to be named ``{frame_id:06d}_No.{instance_id}.png``
    (e.g. ``000000_No.0.png``). Files that do not end in ``.png``, or whose
    name does not split into exactly ``<int>_No.<int>``, are ignored.

    Args:
        instance_label_dir: Root directory holding the ``*_masks`` folders.
        video_name: Video stem (no extension) used to locate the folder.

    Returns:
        dict: ``{instance_id: [mask_file_path, ...]}`` with keys in ascending
        instance id and each path list ordered by numeric frame id. Empty when
        the mask folder does not exist.
    """
    mask_dir = os.path.join(instance_label_dir, f"{video_name}_masks")
    if not os.path.isdir(mask_dir):
        return {}

    # Keep (frame_id, fname) pairs so frames can be ordered numerically; a
    # lexicographic sort of file names is only correct while every frame id
    # is zero-padded to the same width ("10_No.0" < "9_No.0" otherwise).
    frames_by_instance = defaultdict(list)
    for fname in os.listdir(mask_dir):
        if not fname.endswith(".png"):
            continue
        # Parse the file name: "000000_No.0.png" -> ("000000", "0").
        parts = fname[: -len(".png")].split("_No.")
        if len(parts) != 2:
            continue
        frame_id_str, inst_id_str = parts
        try:
            frame_id = int(frame_id_str)
            inst_id = int(inst_id_str)
        except ValueError:
            continue
        frames_by_instance[inst_id].append((frame_id, fname))

    return {
        inst_id: [
            os.path.join(mask_dir, fname)
            for _, fname in sorted(frames_by_instance[inst_id])
        ]
        for inst_id in sorted(frames_by_instance)
    }
def build_instance_prompt(instance_info: dict) -> str:
    """Build a short text prompt for one instance from its InstanceCap entry.

    Expected ``instance_info`` shape (other keys such as ``"Position"`` are
    ignored)::

        {
            "Class": "person",
            "Appearance": "...",
            "Actions and Motion": "...",
            "Position": "..."
        }

    The result reads ``"A {Class}. {Appearance}. {Actions and Motion}."``,
    skipping empty pieces. A missing ``"Class"`` key falls back to
    ``"object"``; a present but empty/None ``"Class"`` is skipped.

    Returns:
        The prompt, ending in ``"."``, or ``""`` when every piece is empty.
    """
    class_name = instance_info.get("Class", "object")
    appearance = instance_info.get("Appearance", "")
    actions = instance_info.get("Actions and Motion", "")

    pieces = []
    if class_name:
        pieces.append(f"A {class_name}")
    pieces.extend(text for text in (appearance, actions) if text)

    # Trim whitespace and trailing periods from each piece before joining with
    # ". ", so sentences that already end in "." do not produce ".." (or "..."
    # leftovers), and drop pieces that were only whitespace/periods.
    cleaned = [piece.strip().rstrip(".").rstrip() for piece in pieces]
    prompt = ". ".join(piece for piece in cleaned if piece)
    # Collapse every run of whitespace (spaces, tabs, newlines) to one space.
    prompt = " ".join(prompt.split())
    if prompt:
        prompt += "."
    return prompt
def process_sample(
    sample: dict,
    instance_label_dir: str,
    video_dir: str,
    min_instances: int,
    max_instances: int,
) -> dict | None:
    """Convert one InstanceCap record into an InstanceV training entry.

    A record is kept only if its ``{video_name}.mp4`` exists in ``video_dir``,
    mask folders exist for at least ``min_instances`` instances, and at least
    ``min_instances`` of its ``"Main Instance"`` entries have both masks and a
    non-empty prompt.

    Args:
        sample: One parsed line of InstanceCap.jsonl; the ``"Video"``,
            ``"Global Description"`` and ``"Structural Description"`` keys are
            read.
        instance_label_dir: Root directory containing ``{video_name}_masks/``.
        video_dir: Directory containing ``{video_name}.mp4``.
        min_instances: Minimum number of usable instances required.
        max_instances: Maximum number of instances kept, taken in ascending
            instance-id order.

    Returns:
        ``{"video", "prompt", "instance_prompts", "instance_mask_dirs"}`` where
        ``instance_mask_dirs[i]`` is ``{"mask_dir", "instance_id",
        "num_frames"}`` and is aligned with ``instance_prompts[i]``; ``None``
        if the sample is rejected.
    """
    # `or ""` also covers a "Video" key that is present but null.
    video_name = get_video_name_from_path(sample.get("Video") or "")
    if not video_name:
        return None

    video_path = os.path.join(video_dir, f"{video_name}.mp4")
    if not os.path.isfile(video_path):
        return None

    instance_masks = find_mask_dirs(instance_label_dir, video_name)
    if len(instance_masks) < min_instances:
        return None

    # Tolerate missing, null or non-object structural fields instead of crashing.
    struct_desc = sample.get("Structural Description")
    if not isinstance(struct_desc, dict):
        struct_desc = {}
    main_instances = struct_desc.get("Main Instance")
    if not isinstance(main_instances, dict):
        main_instances = {}

    # Parse "No.<k>" keys up front and order by the numeric id; sorting the raw
    # strings would put "No.10" before "No.2" and, combined with
    # max_instances, keep the wrong instances.
    parsed_instances = []
    for inst_key, inst_info in main_instances.items():
        try:
            inst_id = int(inst_key.replace("No.", ""))
        except ValueError:
            continue
        parsed_instances.append((inst_id, inst_info))
    parsed_instances.sort(key=lambda item: item[0])

    instance_prompts = []
    instance_mask_dirs = []
    for inst_id, inst_info in parsed_instances:
        if len(instance_prompts) >= max_instances:
            break
        if inst_id not in instance_masks or not isinstance(inst_info, dict):
            continue
        prompt = build_instance_prompt(inst_info)
        if not prompt:
            continue
        mask_files = instance_masks[inst_id]
        instance_prompts.append(prompt)
        # Store the mask directory rather than the per-frame file list; per the
        # original design the masks are loaded dynamically at training time.
        instance_mask_dirs.append({
            "mask_dir": os.path.dirname(mask_files[0]),
            "instance_id": inst_id,
            "num_frames": len(mask_files),
        })

    if len(instance_prompts) < min_instances:
        return None

    # The full prompt is global description + background + camera movement;
    # empty pieces are skipped so no stray leading/trailing spaces appear.
    global_desc = sample.get("Global Description", "")
    background = struct_desc.get("Background Detail", "")
    camera = struct_desc.get("Camera Movement", "")
    full_prompt = " ".join(part for part in (global_desc, background, camera) if part)

    return {
        "video": video_path,
        "prompt": full_prompt,
        "instance_prompts": instance_prompts,
        "instance_mask_dirs": instance_mask_dirs,
    }
def main():
    """Build the InstanceV training JSONL from InstanceCap annotations.

    Reads ``--instancecap_path`` line by line, converts each record with
    ``process_sample``, writes the surviving entries to ``--output_path``
    (one JSON object per line), and prints summary statistics.
    """
    args = parse_args()

    # Load InstanceCap; malformed or non-object lines are reported and skipped.
    print(f"Loading InstanceCap from {args.instancecap_path}")
    samples = []
    with open(args.instancecap_path, "r", encoding="utf-8") as f:
        for line_no, line in enumerate(f, start=1):
            line = line.strip()
            if not line:
                continue
            try:
                record = json.loads(line)
            except json.JSONDecodeError as e:
                print(f"Warning: Failed to parse line {line_no}: {e}")
                continue
            if not isinstance(record, dict):
                # process_sample expects a JSON object (it calls .get on it).
                print(f"Warning: Skipping line {line_no}: expected a JSON object")
                continue
            samples.append(record)
    print(f"Loaded {len(samples)} samples")

    # Convert every sample; rejected ones come back as None.
    output_samples = []
    for sample in tqdm(samples, desc="Processing samples"):
        result = process_sample(
            sample,
            args.instance_label_dir,
            args.video_dir,
            args.min_instances,
            args.max_instances,
        )
        if result is not None:
            output_samples.append(result)
    print(f"Valid samples: {len(output_samples)} / {len(samples)}")

    # os.path.dirname() is "" for a bare file name, and os.makedirs("")
    # raises FileNotFoundError, so only create a directory when there is one.
    output_dir = os.path.dirname(args.output_path)
    if output_dir:
        os.makedirs(output_dir, exist_ok=True)
    with open(args.output_path, "w", encoding="utf-8") as f:
        for sample in output_samples:
            f.write(json.dumps(sample, ensure_ascii=False) + "\n")
    print(f"Saved to {args.output_path}")

    # Summary statistics.
    if output_samples:
        avg_instances = sum(len(s["instance_prompts"]) for s in output_samples) / len(output_samples)
        print(f"Average instances per sample: {avg_instances:.2f}")


if __name__ == "__main__":
    main()