|
|
import os |
|
|
import torch |
|
|
import argparse |
|
|
from PIL import Image |
|
|
from tqdm import tqdm |
|
|
from transformers import AutoModel, AutoProcessor |
|
|
import torch.multiprocessing as mp |
|
|
from torch.utils.data import Dataset, DataLoader |
|
|
import glob |
|
|
|
|
|
|
|
|
# Path to the pretrained SigLIP checkpoint loaded by every worker
# (local mirror of google/siglip-so400m-patch14-384 — TODO confirm).
MODEL_ID = "/mnt/bn/ziyang-storage-cloudnative-hl/huggingface/siglip-so400m-patch14-384"


# Number of frames encoded per forward pass; presumably sized for the
# target GPU's memory — verify against the actual hardware.
BATCH_SIZE = 1024
|
|
|
|
|
def parse_arguments():
    """Parse and return the command-line arguments for this script.

    Returns:
        argparse.Namespace with attributes ``frames_path`` and ``output_dir``.
    """
    parser = argparse.ArgumentParser(
        description="步骤 1: 使用 SigLIP (多GPU) 预计算所有视频帧的嵌入."
    )
    parser.add_argument(
        "--frames-path", "-fp",
        type=str,
        required=True,
        help="包含所有视频帧文件夹的基础目录的绝对路径。",
    )
    parser.add_argument(
        "--output-dir", "-o",
        type=str,
        required=True,
        help="用于保存嵌入.pt文件的输出目录路径。",
    )
    return parser.parse_args()
|
|
|
|
|
class FrameDataset(Dataset):
    """PyTorch Dataset that lazily loads video frame images from disk.

    Each item is a PIL image in RGB mode, or ``None`` when the frame
    cannot be read (the paired ``collate_fn`` filters those out).
    """

    def __init__(self, frame_paths):
        self.frame_paths = frame_paths

    def __len__(self):
        return len(self.frame_paths)

    def __getitem__(self, idx):
        # Swallow any read/decode failure so one corrupt frame does not
        # abort the whole batch; the collate function drops the None.
        try:
            return Image.open(self.frame_paths[idx]).convert("RGB")
        except Exception:
            return None
|
|
|
|
|
def collate_fn(batch):
    """Filter failed (None) samples out of a batch.

    Returns the surviving samples as a list, or ``None`` when every
    sample in the batch failed to load.
    """
    survivors = [sample for sample in batch if sample is not None]
    return survivors or None
|
|
|
|
|
def process_video_chunk(args_tuple):
    """
    Worker function: compute SigLIP embeddings for one chunk of video
    directories on a dedicated GPU and save them as ``{video_name}.pt``.

    Args:
        args_tuple: ``(video_dirs_chunk, frames_base_path, gpu_id, output_dir)``.
            ``frames_base_path`` is currently unused but kept for interface
            compatibility with the dispatcher in ``main()``.

    Each saved .pt file is a dict with keys ``'filenames'`` (sorted frame
    file names) and ``'embeddings'`` (a CPU tensor of image features).
    Per-video failures are logged and skipped; they never kill the worker.
    """
    video_dirs_chunk, frames_base_path, gpu_id, output_dir = args_tuple
    device = f"cuda:{gpu_id}"

    # Each worker process loads its own copy of the model onto its GPU.
    model = AutoModel.from_pretrained(MODEL_ID).to(device).eval()
    processor = AutoProcessor.from_pretrained(MODEL_ID, use_fast=True)

    # position=gpu_id keeps each worker's bar on its own terminal line.
    progress_bar = tqdm(video_dirs_chunk, position=gpu_id, desc=f"GPU-{gpu_id}")

    for video_dir in progress_bar:
        video_name = os.path.basename(video_dir)
        output_path = os.path.join(output_dir, f"{video_name}.pt")

        # Resume support: skip videos whose embeddings were already saved.
        if os.path.exists(output_path):
            progress_bar.write(f"Skipping {video_name}, embeddings already exist.")
            continue

        try:
            frame_files = [f for f in os.listdir(video_dir) if f.endswith(".jpg")]
            if not frame_files:
                continue
            # Frames are assumed to be named "<prefix>_<index>.jpg"; sort by
            # the numeric index. BUG FIX: this sort previously ran outside the
            # try block, so a single malformed filename crashed the whole
            # worker instead of skipping just that video.
            frame_files.sort(key=lambda x: int(x.split("_")[1].split(".")[0]))
            frame_paths = [os.path.join(video_dir, f) for f in frame_files]

            with torch.no_grad():
                dataset = FrameDataset(frame_paths)
                loader = DataLoader(
                    dataset, batch_size=BATCH_SIZE, shuffle=False, num_workers=0,
                    pin_memory=True, collate_fn=collate_fn
                )

                all_frame_embeddings = []
                for image_batch in loader:
                    # collate_fn returns None when every frame in the batch
                    # failed to load.
                    if image_batch is None:
                        continue

                    image_inputs = processor(images=image_batch, return_tensors="pt").to(device)
                    frame_embeddings = model.get_image_features(**image_inputs)
                    all_frame_embeddings.append(frame_embeddings)

                if not all_frame_embeddings:
                    continue

                all_frame_embeddings = torch.cat(all_frame_embeddings, dim=0)

            data_to_save = {
                'filenames': frame_files,
                'embeddings': all_frame_embeddings.cpu()
            }
            # BUG FIX: save atomically (temp file + rename). A run interrupted
            # mid-torch.save used to leave a truncated .pt that the resume
            # check above would then skip forever.
            tmp_path = output_path + ".tmp"
            torch.save(data_to_save, tmp_path)
            os.replace(tmp_path, output_path)

        except Exception as e:
            progress_bar.write(f"Error on GPU-{gpu_id} for video '{video_name}': {e}")
|
|
|
|
|
def main():
    """Coordinate multi-GPU processing.

    Splits the video frame directories evenly across the available GPUs
    and fans out one worker process per chunk via a multiprocessing pool.

    Raises:
        SystemExit: with code 1 when no CUDA GPU is available.
    """
    args = parse_arguments()

    num_gpus = torch.cuda.device_count()
    if num_gpus == 0:
        print("错误: 未找到启用CUDA的GPU。正在退出。")
        # FIX: `exit()` is a site-module interactive helper and not guaranteed
        # to exist in every interpreter; raise SystemExit directly instead.
        raise SystemExit(1)

    print(f"找到 {num_gpus} 个GPU。开始并行处理...")

    os.makedirs(args.output_dir, exist_ok=True)

    # Each immediate subdirectory of frames_path is one video's frame folder.
    video_dirs = [d for d in glob.glob(os.path.join(args.frames_path, '*')) if os.path.isdir(d)]

    if not video_dirs:
        print(f"错误: 在 {args.frames_path} 中未找到视频目录。")
        return

    # Ceil-divide so every GPU gets at most chunk_size videos; when there are
    # fewer videos than GPUs, fewer chunks (and thus fewer workers) are made.
    chunk_size = (len(video_dirs) + num_gpus - 1) // num_gpus
    video_chunks = [video_dirs[i:i + chunk_size] for i in range(0, len(video_dirs), chunk_size)]

    # One argument tuple per chunk; index i doubles as the worker's GPU id.
    process_args = [(video_chunks[i], args.frames_path, i, args.output_dir) for i in range(len(video_chunks))]

    with mp.Pool(processes=num_gpus) as pool:
        pool.map(process_video_chunk, process_args)

    print("\n所有视频帧嵌入已计算并保存。")
|
|
|
|
|
if __name__ == "__main__": |
|
|
mp.set_start_method('spawn', force=True) |
|
|
main() |
|
|
|