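"""Compute the cross-frame semantic consistency score of videos with a DINOv2 or
CLIP vision encoder, distributed across devices via Hugging Face Accelerate.

Example invocation (a sketch; the script name and paths are hypothetical):
    accelerate launch compute_semantic_consistency.py \
        --video_metadata_path metadata.jsonl \
        --video_folder /path/to/videos \
        --model_path openai/clip-vit-large-patch14-336 \
        --saved_path semantic_consistency.jsonl
"""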
import argparse
import os

import numpy as np
import pandas as pd
import torch
from accelerate import PartialState
from accelerate.utils import gather_object
from natsort import natsorted
from torch.utils.data import DataLoader
from tqdm import tqdm
from transformers import AutoImageProcessor, AutoModel

from utils.filter import filter
from utils.logger import logger
from utils.video_dataset import VideoDataset, collate_fn
from utils.video_utils import ALL_FRAME_SAMPLE_METHODS

ALL_MODEL_NAME = [
    "dinov2-small",
    "dinov2-base",
    "dinov2-large",
    "clip-vit-large-patch14",
    "clip-vit-base-patch32",
    "clip-vit-large-patch14-336",
]
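# `--model_path` must contain one of the names above (checked in `main`), e.g.
# "facebook/dinov2-base" or "openai/clip-vit-large-patch14-336".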


def init_model(model_name, device):
    processor = AutoImageProcessor.from_pretrained(model_name)
    model = AutoModel.from_pretrained(model_name).to(device)
    return processor, model
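

# For L2-normalized frame features f_0, ..., f_{T-1}, the adjacent-frame score is
# the cosine similarity f_t . f_{t+1}; rolling the feature matrix by one row and
# reading the diagonal of the product yields all T - 1 adjacent similarities.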
def compute_adjacent_similarity(frame_features):
    # Normalize so that dot products are cosine similarities.
    frame_features /= frame_features.norm(dim=-1, keepdim=True)
    roll_frame_features = torch.roll(frame_features, shifts=-1, dims=0)
    similarity_matrix = (
        frame_features.squeeze(dim=1).cpu().numpy()
        @ roll_frame_features.squeeze(dim=1).cpu().numpy().T
    )
    # The diagonal holds the (t, t + 1) pairs; the last entry wraps around, so drop it.
    return np.diag(similarity_matrix).tolist()[:-1]
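# Note: the full T x T similarity matrix above is needed only for its diagonal.
# An equivalent, memory-lighter variant (a sketch, assuming 2-D [T, D] features):
#     normed = frame_features / frame_features.norm(dim=-1, keepdim=True)
#     adjacent = (normed[:-1] * normed[1:]).sum(dim=-1).cpu().tolist()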


def parse_args():
    parser = argparse.ArgumentParser(description="Compute the semantic consistency score across frames.")
    parser.add_argument(
        "--video_metadata_path", type=str, required=True, help="The path to the video dataset metadata (csv/jsonl)."
    )
    parser.add_argument(
        "--video_path_column",
        type=str,
        default="video_path",
        help="The column containing the video path (an absolute path or a path relative to the video_folder).",
    )
    parser.add_argument("--video_folder", type=str, default="", help="The video folder.")
    parser.add_argument(
        "--model_path", type=str, default="openai/clip-vit-large-patch14-336", help="The path to the DINO/CLIP model."
    )
    parser.add_argument(
        "--frame_sample_method",
        type=str,
        choices=ALL_FRAME_SAMPLE_METHODS,
        default="keyframe+first",
        help="The frame sampling method.",
    )
    parser.add_argument("--num_sampled_frames", type=int, default=1, help="The number of sampled frames.")
    parser.add_argument("--sample_stride", type=int, default=None, help="The stride between two sampled frames.")
    parser.add_argument("--batch_size", type=int, default=1, help="The batch size for the video dataset.")
    parser.add_argument("--num_workers", type=int, default=1, help="The number of workers for the video dataset.")
    parser.add_argument("--saved_path", type=str, required=True, help="The path to save the output results (csv/jsonl).")
    parser.add_argument("--saved_freq", type=int, default=1, help="The frequency (in batches) to save the output results.")

    parser.add_argument("--basic_metadata_path", type=str, default=None, help="The path to the basic metadata (csv/jsonl).")
    parser.add_argument("--min_resolution", type=float, default=0, help="The resolution threshold.")
    parser.add_argument("--min_duration", type=float, default=-1, help="The minimum duration.")
    parser.add_argument("--max_duration", type=float, default=-1, help="The maximum duration.")
    parser.add_argument(
        "--aesthetic_score_metadata_path", type=str, default=None, help="The path to the aesthetic score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_aesthetic_score", type=float, default=4.0, help="The aesthetic score threshold.")
    parser.add_argument(
        "--aesthetic_score_siglip_metadata_path",
        type=str,
        default=None,
        help="The path to the aesthetic score (SigLIP) metadata (csv/jsonl).",
    )
    parser.add_argument("--min_aesthetic_score_siglip", type=float, default=4.0, help="The aesthetic score (SigLIP) threshold.")
    parser.add_argument(
        "--text_score_metadata_path", type=str, default=None, help="The path to the video text score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_text_score", type=float, default=0.02, help="The text score threshold.")
    parser.add_argument(
        "--motion_score_metadata_path", type=str, default=None, help="The path to the video motion score metadata (csv/jsonl)."
    )
    parser.add_argument("--min_motion_score", type=float, default=2, help="The minimum motion score threshold.")
    parser.add_argument("--max_motion_score", type=float, default=999999, help="The maximum motion score threshold.")

    args = parser.parse_args()
    return args


def main():
    args = parse_args()

    if args.video_metadata_path.endswith(".csv"):
        video_metadata_df = pd.read_csv(args.video_metadata_path)
    elif args.video_metadata_path.endswith(".jsonl"):
        video_metadata_df = pd.read_json(args.video_metadata_path, lines=True)
    else:
        raise ValueError("The video_metadata_path must end with .csv or .jsonl.")
    video_path_list = video_metadata_df[args.video_path_column].tolist()

    if not (args.saved_path.endswith(".csv") or args.saved_path.endswith(".jsonl")):
        raise ValueError("The saved_path must end with .csv or .jsonl.")

    # Resume: skip videos that already have results in the saved file.
    if os.path.exists(args.saved_path):
        if args.saved_path.endswith(".csv"):
            saved_metadata_df = pd.read_csv(args.saved_path)
        elif args.saved_path.endswith(".jsonl"):
            saved_metadata_df = pd.read_json(args.saved_path, lines=True)
        saved_video_path_list = saved_metadata_df[args.video_path_column].tolist()
        video_path_list = list(set(video_path_list).difference(set(saved_video_path_list)))
        logger.info(
            f"Resume from {args.saved_path}: {len(saved_video_path_list)} processed "
            f"and {len(video_path_list)} to be processed."
        )
    video_path_list = filter(
        video_path_list,
        basic_metadata_path=args.basic_metadata_path,
        min_resolution=args.min_resolution,
        min_duration=args.min_duration,
        max_duration=args.max_duration,
        aesthetic_score_metadata_path=args.aesthetic_score_metadata_path,
        min_aesthetic_score=args.min_aesthetic_score,
        aesthetic_score_siglip_metadata_path=args.aesthetic_score_siglip_metadata_path,
        min_aesthetic_score_siglip=args.min_aesthetic_score_siglip,
        text_score_metadata_path=args.text_score_metadata_path,
        min_text_score=args.min_text_score,
        motion_score_metadata_path=args.motion_score_metadata_path,
        min_motion_score=args.min_motion_score,
        max_motion_score=args.max_motion_score,
        video_path_column=args.video_path_column,
    )
    # Sort to guarantee the same order (and thus the same split) in each process.
    video_path_list = natsorted(video_path_list)

    if not any(name in args.model_path for name in ALL_MODEL_NAME):
        raise ValueError(f"The model_path should contain one of the following model names: {ALL_MODEL_NAME}.")

    state = PartialState()
    if state.is_main_process:
        # Download the model in the main process first to avoid a download race across processes.
        processor, model = init_model(args.model_path, "cpu")
    state.wait_for_everyone()
    processor, model = init_model(args.model_path, state.device)

    index = len(video_path_list) - len(video_path_list) % state.num_processes
    # Avoid the NCCL timeout in the final gather operation.
    logger.warning(
        f"Drop the last {len(video_path_list) % state.num_processes} videos "
        "to ensure each process handles the same number of videos."
    )
    video_path_list = video_path_list[:index]
    logger.info(f"{len(video_path_list)} videos are to be processed.")

    result_dict = {
        args.video_path_column: [],
        "similarity_cross_frame": [],
        "similarity_mean": [],
        "sample_frame_idx": [],
    }

    with state.split_between_processes(video_path_list) as splitted_video_path_list:
        video_dataset = VideoDataset(
            dataset_inputs={args.video_path_column: splitted_video_path_list},
            video_folder=args.video_folder,
            video_path_column=args.video_path_column,
            sample_method=args.frame_sample_method,
            num_sampled_frames=args.num_sampled_frames,
            sample_stride=args.sample_stride,
        )
        video_loader = DataLoader(
            video_dataset, batch_size=args.batch_size, num_workers=args.num_workers, collate_fn=collate_fn
        )
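
        # An empty batch can occur when every video in it fails to load (the
        # provided collate_fn is assumed to drop unreadable videos), hence the
        # `len(batch) > 0` guard below.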
        for idx, batch in enumerate(tqdm(video_loader)):
            if len(batch) > 0:
                batch_video_path = []
                batch_frame = []
                batch_sampled_frame_idx = []
                # At least two frames are required to calculate cross-frame semantic consistency.
                for path, frame, frame_idx in zip(
                    batch["path"], batch["sampled_frame"], batch["sampled_frame_idx"]
                ):
                    if len(frame) > 1:
                        batch_video_path.append(path)
                        batch_frame.append(frame)
                        batch_sampled_frame_idx.append(frame_idx)
                    else:
                        logger.warning(f"Skip {path} because it only has {len(frame)} frames.")

                # Skip the batch if all videos in it were filtered out; otherwise the
                # processor would be called with an empty image list.
                if len(batch_frame) > 0:
                    frame_num_list = [len(video_frames) for video_frames in batch_frame]
                    # [B, T, H, W, C] => [(B * T), H, W, C]
                    reshaped_batch_frame = [frame for video_frames in batch_frame for frame in video_frames]
                    with torch.no_grad():
                        inputs = processor(images=reshaped_batch_frame, return_tensors="pt").to(state.device)
                        if "dino" in args.model_path.lower():
                            # DINOv2: mean-pool the patch tokens of the last hidden state.
                            frame_features = model(**inputs).last_hidden_state.mean(dim=1)
                        else:  # CLIP
                            frame_features = model.get_image_features(**inputs)

                    # Each video may have a different number of sampled frames.
                    # Map the flattened frame features back to the per-video shape.
                    batch_frame_features = torch.split(frame_features, frame_num_list)
                    batch_simi_cross_frame = [
                        compute_adjacent_similarity(frame_features) for frame_features in batch_frame_features
                    ]
                    batch_similarity_mean = [
                        sum(simi_cross_frame) / len(simi_cross_frame) for simi_cross_frame in batch_simi_cross_frame
                    ]

                    if args.video_folder == "":
                        saved_video_path_list = batch_video_path
                    else:
                        saved_video_path_list = [
                            os.path.relpath(video_path, args.video_folder) for video_path in batch_video_path
                        ]
                    result_dict[args.video_path_column].extend(saved_video_path_list)
                    result_dict["similarity_cross_frame"].extend(batch_simi_cross_frame)
                    result_dict["similarity_mean"].extend(batch_similarity_mean)
                    result_dict["sample_frame_idx"].extend(batch_sampled_frame_idx)
            # Save the metadata in the main process every saved_freq batches and on the last batch.
            if (idx % args.saved_freq) == 0 or idx == len(video_loader) - 1:
                state.wait_for_everyone()
                gathered_result_dict = {k: gather_object(v) for k, v in result_dict.items()}
                if state.is_main_process and len(gathered_result_dict[args.video_path_column]) != 0:
                    result_df = pd.DataFrame(gathered_result_dict)
                    # Appending to an existing file is not supported (e.g. on OSS), so read
                    # the previous results back and rewrite the whole file.
                    if args.saved_path.endswith(".csv"):
                        if os.path.exists(args.saved_path):
                            saved_df = pd.read_csv(args.saved_path)
                            result_df = pd.concat([saved_df, result_df], ignore_index=True)
                        result_df.to_csv(args.saved_path, index=False)
                    elif args.saved_path.endswith(".jsonl"):
                        if os.path.exists(args.saved_path):
                            saved_df = pd.read_json(args.saved_path, orient="records", lines=True)
                            result_df = pd.concat([saved_df, result_df], ignore_index=True)
                        result_df.to_json(args.saved_path, orient="records", lines=True, force_ascii=False)
                    logger.info(f"Save result to {args.saved_path}.")
                # Reset the buffer on every process after each gather, so results
                # are not gathered (and saved) twice.
                for k in result_dict.keys():
                    result_dict[k] = []
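
# A downstream consumer can threshold the saved scores, e.g. (a sketch; the file
# name and the 0.8 cut-off are hypothetical):
#     df = pd.read_json("semantic_consistency.jsonl", lines=True)
#     consistent = df[df["similarity_mean"] > 0.8]["video_path"]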


if __name__ == "__main__":
    main()