# VideoSimpleQA / compute_video_emb.py
# (Initial upload of all project files — hzy, commit 608eb1a verified)
import os
import torch
import argparse
from PIL import Image
from tqdm import tqdm
from transformers import AutoModel, AutoProcessor
import torch.multiprocessing as mp
from torch.utils.data import Dataset, DataLoader
import glob
# --- Configuration ---
# Path to a locally stored SigLIP checkpoint (loaded via transformers AutoModel).
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"
BATCH_SIZE = 1024  # frames per forward pass; adjust to your GPU VRAM
def parse_arguments():
    """Parse and return the command-line arguments for this script.

    Returns:
        argparse.Namespace with `frames_path` (base directory containing one
        sub-folder of frames per video) and `output_dir` (where the .pt
        embedding files are written).
    """
    arg_parser = argparse.ArgumentParser(
        description="步骤 1: 使用 SigLIP (多GPU) 预计算所有视频帧的嵌入."
    )
    arg_parser.add_argument(
        "--frames-path", "-fp",
        type=str,
        required=True,
        help="包含所有视频帧文件夹的基础目录的绝对路径。",
    )
    arg_parser.add_argument(
        "--output-dir", "-o",
        type=str,
        required=True,
        help="用于保存嵌入.pt文件的输出目录路径。",
    )
    return arg_parser.parse_args()
class FrameDataset(Dataset):
    """PyTorch Dataset that lazily loads video frame images as RGB PIL images.

    Frames that cannot be opened yield None; pair with `collate_fn`, which
    filters the Nones out of each batch.
    """

    def __init__(self, frame_paths):
        # Ordered list of absolute frame-file paths for one video.
        self.frame_paths = frame_paths

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        frame_path = self.frame_paths[idx]
        # Unreadable/corrupt frames are reported as None rather than raising,
        # so one bad file doesn't abort the whole video.
        try:
            return Image.open(frame_path).convert("RGB")
        except Exception:
            return None
def collate_fn(batch):
    """Collate function that drops failed (None) frames from a batch.

    Returns the list of successfully loaded items, or None when every item
    in the batch failed to load.
    """
    valid_items = [item for item in batch if item is not None]
    return valid_items if valid_items else None
def process_video_chunk(args_tuple):
    """
    Worker function: compute SigLIP embeddings for one chunk of videos on a
    specific GPU and save one .pt file per video.

    args_tuple: (video_dirs_chunk, frames_base_path, gpu_id, output_dir)
      - video_dirs_chunk: list of per-video frame-folder paths to process
      - frames_base_path: base frames directory (unpacked but not used here;
        kept for interface compatibility with the pool caller)
      - gpu_id: CUDA device index this worker is pinned to
      - output_dir: directory where "<video_name>.pt" files are written

    Each saved file is a dict: {'filenames': [frame names], 'embeddings':
    CPU tensor of per-frame image features}.
    """
    video_dirs_chunk, frames_base_path, gpu_id, output_dir = args_tuple
    device = f"cuda:{gpu_id}"
    # Load the model and processor for the assigned GPU inside the worker
    # process — each spawned process needs its own copy.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)
    progress_bar = tqdm(video_dirs_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")
    for video_dir in progress_bar:
        video_name = os.path.basename(video_dir)
        output_path = os.path.join(output_dir, f"{video_name}.pt")
        # Skip videos whose embeddings already exist so interrupted runs can
        # be resumed.
        if os.path.exists(output_path):
            progress_bar.write(f"Skipping {video_name}, embeddings already exist.")
            continue
        frame_files = [f for f in os.listdir(video_dir) if f.endswith(".jpg")]
        if not frame_files:
            continue
        # Numeric sort by frame index; assumes filenames look like
        # "frame_<idx>.jpg". NOTE(review): a .jpg not matching that pattern
        # raises here, OUTSIDE the try below — confirm frame naming upstream.
        frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
        frame_paths = [os.path.join(video_dir, f) for f in frame_files]
        try:
            with torch.no_grad():
                dataset = FrameDataset(frame_paths)
                loader = DataLoader(
                    dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                    pin_memory=True, collate_fn=collate_fn
                )
                all_frame_embeddings = []
                for image_batch in loader:
                    # collate_fn returns None when every frame in the batch
                    # failed to load.
                    if image_batch is None:
                        continue
                    image_inputs = processor(images=image_batch, return_tensors="pt").to(device)
                    frame_embeddings = model.get_image_features(**image_inputs)
                    all_frame_embeddings.append(frame_embeddings)
                if not all_frame_embeddings:
                    continue
                all_frame_embeddings = torch.cat(all_frame_embeddings, dim=0)
                # Move the tensor to CPU before saving so the file can later
                # be loaded without CUDA being available.
                data_to_save = {
                    'filenames': frame_files,
                    'embeddings': all_frame_embeddings.cpu()
                }
                torch.save(data_to_save, output_path)
        except Exception as e:
            # Best-effort: log the failure and continue with the next video.
            progress_bar.write(f"Error on GPU-{gpu_id} for video '{video_name}': {e}")
def main():
    """Coordinate multi-GPU embedding precomputation over all video folders.

    Splits the video directories into one contiguous chunk per GPU and fans
    the chunks out to a multiprocessing pool running `process_video_chunk`.
    """
    args = parse_arguments()
    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        exit(1)
    print(f"找到 {num_gpus} 个GPU。开始并行处理...")
    # Ensure the output directory exists.
    os.makedirs(args.output_dir, exist_ok=True)
    # Every subdirectory of the frames path is treated as one video.
    video_dirs = [p for p in glob.glob(os.path.join(args.frames_path, '*')) if os.path.isdir(p)]
    if not video_dirs:
        print(f"错误: 在 {args.frames_path} 中未找到视频目录。")
        return
    # One contiguous chunk per GPU (ceiling division so nothing is dropped).
    chunk_size = -(-len(video_dirs) // num_gpus)
    video_chunks = [
        video_dirs[start:start + chunk_size]
        for start in range(0, len(video_dirs), chunk_size)
    ]
    # One argument tuple per worker; the chunk's position doubles as the
    # GPU index it runs on.
    worker_args = [
        (chunk, args.frames_path, gpu_idx, args.output_dir)
        for gpu_idx, chunk in enumerate(video_chunks)
    ]
    with mp.Pool(processes=num_gpus) as pool:
        pool.map(process_video_chunk, worker_args)
    print("\n所有视频帧嵌入已计算并保存。")
if __name__ == "__main__":
    # Force the 'spawn' start method: CUDA does not work reliably with
    # fork-based multiprocessing, and each worker loads its own model copy.
    mp.set_start_method('spawn', force=True)
    main()