# Source: Xseg-Baseline/correspondence/SegSwap/data/process_data_batch.py
# Uploaded by YuqianFu with huggingface_hub ("Upload folder using huggingface_hub"), commit 944cdc2.
import os
import json
from lzstring import LZString
from pycocotools import mask as mask_utils
import numpy as np
from PIL import Image
from decord import VideoReader
from decord import cpu
import argparse
import cv2
from time import time
def save_frames(frames, frame_idxes, output_folder, is_aria=False, *, scale=None, min_height=1408):
    """Downscale frames and write each one to ``output_folder/<frame_idx>.jpg``.

    Args:
        frames: Iterable of images whose ``shape`` starts with ``(H, W)``. They
            are passed directly to ``cv2.resize``/``cv2.imwrite``, so they should
            already be in OpenCV's BGR channel order.
        frame_idxes: Frame indices paired one-to-one with ``frames`` (via
            ``zip``); each index becomes the output file stem.
        output_folder: Existing directory to write the JPEGs into.
        is_aria: Chooses the default downscale factor when ``scale`` is None:
            2 for Aria (ego) frames, 4 otherwise.
        scale: Optional integer downscale factor overriding the ``is_aria`` default.
        min_height: Minimum source height. At the first frame shorter than this,
            that frame and all remaining frames are skipped and a warning is printed.

    Returns:
        Number of frames actually written.
    """
    if scale is None:
        scale = 2 if is_aria else 4
    written = 0
    for img, fidx in zip(frames, frame_idxes):
        height, width = img.shape[:2]
        if height < min_height:
            # Stop at the first undersized frame (presumably the whole video is
            # low-res). This used to happen silently, leaving an empty folder
            # with no hint why.
            print(f"Skipping remaining frames for {output_folder}: height {height} < {min_height}")
            break
        resized = cv2.resize(img, (width // scale, height // scale))
        out_path = os.path.join(output_folder, f'{fidx}.jpg')
        # cv2.imwrite reports failure through its return value, not an exception.
        if not cv2.imwrite(out_path, resized):
            print(f"Failed to write {out_path}")
            continue
        written += 1
    return written
def _sample_interval(fps, source_fps=30):
    """Return the frame stride that samples a ``source_fps`` video at ~``fps`` frames/s.

    With the default 30 fps source, ``fps=1`` gives 30 and ``fps=30`` gives 1;
    rates above ``source_fps`` clamp to 1 (every frame).

    Raises:
        ValueError: if ``fps`` is not positive (previously ``fps=0`` raised
            ZeroDivisionError and negative values were silently accepted).
    """
    if fps <= 0:
        raise ValueError(f"fps must be positive, got {fps}")
    return int(max(1, source_fps // fps))


def _extract_frames(vr, subsample_idx, batch_size, output_folder, is_aria, label):
    """Decode ``subsample_idx`` from ``vr`` in chunks of ``batch_size`` and save them.

    A failing chunk is reported and skipped, so one bad chunk does not abort the
    remaining ones.
    """
    num_batches = (len(subsample_idx) + batch_size - 1) // batch_size
    for batch_no, start in enumerate(range(0, len(subsample_idx), batch_size), start=1):
        batch_idx = subsample_idx[start:start + batch_size]
        print(f"Processing {label} batch {batch_no}/{num_batches}")
        try:
            # The decoded batch is only referenced for the duration of this call,
            # so it is freed before the next chunk is decoded. Previously it
            # stayed alive during the next decode, doubling peak memory.
            # The channel axis is reversed because decord decodes RGB while
            # OpenCV writes BGR.
            save_frames(frames=vr.get_batch(batch_idx).asnumpy()[..., ::-1],
                        frame_idxes=batch_idx,
                        output_folder=output_folder, is_aria=is_aria)
        except Exception as e:
            print(f"Error processing {label} batch: {e}")


def processVideo(takepath, take_name, ego_cam, exo_cams, outputpath, take_id, fps=1, batch_size=1000):
    """Sample frames from a take's ego and exo videos and save them as JPEGs.

    Frames are read from ``{takepath}/{take_name}/frame_aligned_videos/<cam>.mp4``
    and written to ``{outputpath}/{take_id}/<cam>/<frame_idx>.jpg``. Sample
    indices are computed from the ego video's length and reused for every exo
    camera.

    Args:
        takepath: Root directory containing the take folders.
        take_name: Take folder name under ``takepath``.
        ego_cam: Ego camera name (file stem of its video).
        exo_cams: Exo camera names to extract at the same indices.
        outputpath: Output data root.
        take_id: Take identifier used as the output sub-folder.
        fps: Target sampling rate, assuming a 30 fps source. 1 keeps every 30th
            frame, 30 keeps every frame, and other values use stride
            ``max(1, 30 // fps)``.
        batch_size: Frames decoded per chunk, which bounds peak memory.

    Returns:
        The list of sampled frame indices, or ``-1`` if the ego video is missing.

    Raises:
        ValueError: if ``fps`` or ``batch_size`` is not positive.
    """
    sample_interval = _sample_interval(fps)
    if batch_size <= 0:
        raise ValueError(f"batch_size must be positive, got {batch_size}")

    video_dir = f"{takepath}/{take_name}/frame_aligned_videos"
    ego_path = f"{video_dir}/{ego_cam}.mp4"
    if not os.path.exists(ego_path):
        return -1

    # Subsample the ego video
    vr = VideoReader(ego_path, ctx=cpu(0))
    len_video = len(vr)
    subsample_idx = np.arange(0, len_video, sample_interval)
    print(f"Video length: {len_video}, Sample interval: {sample_interval}, Total frames to extract: {len(subsample_idx)}")

    # Ego frames are (re)written even when the folder already exists.
    ego_out = f"{outputpath}/{take_id}/{ego_cam}"
    os.makedirs(ego_out, exist_ok=True)
    _extract_frames(vr, subsample_idx, batch_size, ego_out, is_aria=True, label="ego cam")
    del vr  # release the decoder before opening the exo videos

    # Subsample the exo videos at the ego indices.
    for exo_cam in exo_cams:
        exo_out = f"{outputpath}/{take_id}/{exo_cam}"
        if os.path.isdir(exo_out):
            # An existing exo folder is treated as already extracted.
            continue
        try:
            vr_exo = VideoReader(f"{video_dir}/{exo_cam}.mp4", ctx=cpu(0))
        except Exception as e:
            print(f"{exo_cam} not available: {e}")
            continue
        os.makedirs(exo_out, exist_ok=True)
        _extract_frames(vr_exo, subsample_idx, batch_size, exo_out, is_aria=False, label=exo_cam)
        del vr_exo
    return subsample_idx.tolist()
def decode_mask(width, height, encoded_mask):
    """Turn an LZString-compressed mask string into a COCO-style RLE dict.

    Args:
        width: Mask width in pixels.
        height: Mask height in pixels.
        encoded_mask: RLE ``counts`` string compressed with LZString's
            URI-component encoding.

    Returns:
        ``{"size": [height, width], "counts": <ascii str>}``, the compressed-RLE
        layout used by pycocotools (``size`` is height first). Returns ``None``
        when the payload cannot be decoded.
    """
    try:
        decompressed = LZString.decompressFromEncodedURIComponent(encoded_mask)
    except Exception:
        return None
    # lzstring signals some invalid inputs by returning None instead of raising.
    # The old code then crashed on ``None.encode()``.
    if decompressed is None:
        return None
    try:
        # RLE counts must be ASCII. Anything else is a corrupt payload; it used
        # to raise UnicodeDecodeError instead of honoring the None contract.
        decompressed.encode("ascii")
    except UnicodeEncodeError:
        return None
    return {
        "size": [height, width],
        "counts": decompressed,
    }
def processMask(anno, new_anno):
    """Decode every per-frame mask in ``anno`` into ``new_anno`` (in place).

    ``anno`` is indexed as ``anno[object][camera]["annotation"][frame]``, and each
    frame entry supplies ``width``, ``height`` and ``encodedMask``. The result is
    written as ``new_anno[object][camera][frame] = decode_mask(...)``, which
    replaces any existing entry for that object. Nothing is returned.
    """
    for obj_key, per_cam in anno.items():
        obj_out = new_anno[obj_key] = {}
        for cam_key, cam_entry in per_cam.items():
            cam_out = obj_out[cam_key] = {}
            for frame_key, frame_info in cam_entry["annotation"].items():
                cam_out[frame_key] = decode_mask(
                    frame_info["width"],
                    frame_info["height"],
                    frame_info["encodedMask"],
                )
def _parse_args(argv=None):
    """Parse command-line options (``sys.argv[1:]`` when ``argv`` is None)."""
    parser = argparse.ArgumentParser()
    parser.add_argument(
        "--takepath",
        help="EgoExo take data root",
        required=True
    )
    parser.add_argument(
        "--annotationpath",
        help="Annotations json file path",
        required=True
    )
    parser.add_argument(
        "--split_path",
        help="path to split.json",
        required=True
    )
    parser.add_argument(
        "--split",
        help="train/val/test split to process",
        required=True
    )
    parser.add_argument(
        "--outputpath",
        help="Output data root",
        required=True
    )
    parser.add_argument(
        "--fps",
        type=int,
        default=1,
        help="Sampling frame rate (1 for 1fps, 30 for 30fps)"
    )
    parser.add_argument(
        "--batch_size",
        type=int,
        default=1000,
        help="Batch size for processing frames to avoid memory issues"
    )
    return parser.parse_args(argv)


def _split_cameras(object_masks):
    """Return sorted ``(ego_cams, exo_cams)`` for all cameras that have masks.

    A camera is ego when its name contains "aria". Sorting makes the
    ``ego_cams[0]`` choice deterministic. Iterating a set of strings, as before,
    can change order between runs because of string hash randomization.
    """
    valid_cams = set()
    for cams in object_masks.values():
        valid_cams.update(cams)
    ego_cams = sorted(c for c in valid_cams if 'aria' in c)
    exo_cams = sorted(c for c in valid_cams if 'aria' not in c)
    return ego_cams, exo_cams


def main(argv=None):
    """Extract frames and decoded masks for every take of the requested split."""
    args = _parse_args(argv)
    with open(args.split_path, "r") as fp:
        data_split = json.load(fp)

    # Use the requested split; an unknown split name falls back to a single
    # hard-coded take for debugging.
    if args.split in data_split:
        take_list = data_split[args.split]
    else:
        print(f"Split '{args.split}' not found in {args.split_path}")
        take_list = ['6ca51642-c089-4989-b5a3-07977ec927d7']
        print(f"Debug mode: processing only {take_list}")

    os.makedirs(args.outputpath, exist_ok=True)
    # Read the annotation file
    with open(args.annotationpath, "r") as f:
        annos = json.load(f)['annotations']

    print(f"Processing {len(take_list)} takes at {args.fps}fps with batch size {args.batch_size}")
    start = time()
    for idx, take_id in enumerate(take_list):
        print(f"\n=== Processing take {idx+1}/{len(take_list)}: {take_id} ===")
        take_out = f"{args.outputpath}/{take_id}"
        anno_out = f"{take_out}/annotation.json"
        # annotation.json is written last, so it (not the folder) marks a
        # finished take. The folder used to be created before any work, which
        # made failed or interrupted takes look "done" forever.
        if os.path.exists(anno_out):
            print(f"{take_id} already done!")
            continue

        # Get the corresponding take name
        if take_id not in annos:
            print(f"Take {take_id} not found in annotations!")
            continue
        anno = annos[take_id]
        take_name = anno["take_name"]

        ego_cams, exo_cams = _split_cameras(anno['object_masks'])
        if not ego_cams:
            print(f"No ego camera found for take {take_id}")
            continue
        if len(ego_cams) > 1:
            # Only the first (alphabetically) ego camera is processed.
            print(f"{take_id} HAS MORE THAN ONE EGO: {ego_cams}")

        print(f"Processing take {take_id} {take_name}")
        print(f"Ego cams: {ego_cams}")
        print(f"Exo cams: {exo_cams}")

        # Process the masks
        print("Start processing masks")
        new_anno = {"masks": {}}
        processMask(anno['object_masks'], new_anno["masks"])

        # Process the videos
        print("Start processing Videos")
        subsample_idx = processVideo(
            args.takepath,
            take_name,
            ego_cam=ego_cams[0],
            exo_cams=exo_cams,
            outputpath=args.outputpath,
            take_id=take_id,
            fps=args.fps,
            batch_size=args.batch_size
        )
        if subsample_idx == -1:
            print(f"{args.takepath}/{take_name}/frame_aligned_videos/{ego_cams[0]}.mp4 does not exist")
            continue

        new_anno["subsample_idx"] = subsample_idx
        new_anno["fps"] = args.fps
        # Save the annotation
        os.makedirs(take_out, exist_ok=True)
        with open(anno_out, "w") as f:
            json.dump(new_anno, f)
        print(f"Completed take {take_id}, extracted {len(subsample_idx)} frames")

    elapsed = time() - start
    print(f"\nTotal processing time: {elapsed:.2f} seconds")
    if take_list:  # avoid ZeroDivisionError on an empty split
        print(f"Average time per take: {elapsed/len(take_list):.2f} seconds")


if __name__ == "__main__":
    main()