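"""Filter videos by visual quality.

Each video is scored with HyperIQA on its first, middle, and last frames, and videos whose
average score is at least 40 are copied from the input directory tree to the output directory
tree. The work is split across all available GPUs, with several worker processes per GPU.
"""
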
import os
import tqdm
import torch
import torchvision
import shutil
from multiprocessing import Process
import numpy as np
from decord import VideoReader
from einops import rearrange

from eval.hyper_iqa import HyperNet, TargetNet

# (input_path, output_path) pairs collected by gather_paths() and consumed by the worker processes
paths = []


def gather_paths(input_dir, output_dir):
    # Recursively collect every .mp4 under input_dir, skipping videos already present in output_dir
    for video in tqdm.tqdm(sorted(os.listdir(input_dir))):
        if video.endswith(".mp4"):
            video_input = os.path.join(input_dir, video)
            video_output = os.path.join(output_dir, video)
            if os.path.isfile(video_output):
                continue
            paths.append((video_input, video_output))
        elif os.path.isdir(os.path.join(input_dir, video)):
            gather_paths(os.path.join(input_dir, video), os.path.join(output_dir, video))


def read_video(video_path: str):
    # Sample the first, middle, and last frames and return them as a (3, c, h, w) float tensor in [0, 1]
    vr = VideoReader(video_path)
    first_frame = vr[0].asnumpy()
    middle_frame = vr[len(vr) // 2].asnumpy()
    last_frame = vr[-1].asnumpy()
    vr.seek(0)
    video_frames = np.stack([first_frame, middle_frame, last_frame], axis=0)
    video_frames = torch.from_numpy(rearrange(video_frames, "b h w c -> b c h w"))
    video_frames = video_frames / 255.0
    return video_frames


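# Worker entry point: scores each (input, output) pair in its chunk of `paths` on GPU `device_id`.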
def func(paths, device_id):
    device = f"cuda:{device_id}"

    model_hyper = HyperNet(16, 112, 224, 112, 56, 28, 14, 7).to(device)
    model_hyper.train(False)  # inference only

    # Load the HyperIQA weights pretrained on the KonIQ-10k dataset
    model_hyper.load_state_dict(torch.load("checkpoints/auxiliary/koniq_pretrained.pkl", map_location=device))

    transforms = torchvision.transforms.Compose(
        [
            torchvision.transforms.CenterCrop(size=224),
            torchvision.transforms.Normalize(mean=(0.485, 0.456, 0.406), std=(0.229, 0.224, 0.225)),
        ]
    )

    for video_input, video_output in paths:
        try:
            video_frames = read_video(video_input)
            video_frames = transforms(video_frames)
            video_frames = video_frames.to(device)
            paras = model_hyper(video_frames)

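            # HyperIQA is a two-stage model: the hyper network above predicts the weights ("paras") of a
            # small per-image target network, which maps paras["target_in_vec"] to a quality score
            # (the KonIQ-10k scale is roughly 0-100; higher is better)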
            model_target = TargetNet(paras).to(device)
            for param in model_target.parameters():
                param.requires_grad = False

            # One score per sampled frame, averaged into a single video-level score
            pred = model_target(paras["target_in_vec"])
            quality_score = pred.mean().item()
            print(f"Input video: {video_input}\nVisual quality score: {quality_score:.2f}")

            # Copy only videos that clear the quality threshold
            if quality_score >= 40:
                os.makedirs(os.path.dirname(video_output), exist_ok=True)
                shutil.copy(video_input, video_output)
        except Exception as e:
            print(e)


def split(a, n):
    # Split list a into n contiguous chunks whose lengths differ by at most one
    k, m = divmod(len(a), n)
    return (a[i * k + min(i, m) : (i + 1) * k + min(i + 1, m)] for i in range(n))

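# For example, split(list(range(10)), 3) yields [0, 1, 2, 3], [4, 5, 6], [7, 8, 9]:
# the first len(a) % n chunks receive one extra element.

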
def filter_visual_quality_multi_gpus(input_dir, output_dir, num_workers):
    gather_paths(input_dir, output_dir)
    num_devices = torch.cuda.device_count()
    if num_devices == 0:
        raise RuntimeError("No GPUs found")
    split_paths = list(split(paths, num_workers * num_devices))
    processes = []

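    # Launch num_workers processes per GPU: process index i * num_workers + j runs
    # split_paths[i * num_workers + j] on device cuda:i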
    for i in range(num_devices):
        for j in range(num_workers):
            process_index = i * num_workers + j
            process = Process(target=func, args=(split_paths[process_index], i))
            process.start()
            processes.append(process)

    for process in processes:
        process.join()


if __name__ == "__main__":
    input_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/av_synced_high"
    output_dir = "/mnt/bn/maliva-gen-ai-v2/chunyu.li/VoxCeleb2/high_visual_quality"
    num_workers = 20

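    # Each worker process loads its own copy of the model, so the total process count
    # (num_workers * number of GPUs) should be sized to fit in GPU memory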
    filter_visual_quality_multi_gpus(input_dir, output_dir, num_workers)