PyTorch
ssl-aasist
custom_code
ash56's picture
Add files using upload-large-folder tool
878264b verified
raw
history blame
8.32 kB
# Copyright Howto100M authors.
# Copyright (c) Facebook, Inc. All Rights Reserved
import torch as th
import pandas as pd
import os
import numpy as np
import ffmpeg
import random
from torch.utils.data import Dataset
class VideoLoader(Dataset):
"""modified from how2's video_feature_extractor."""
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
hflip=False,
**kwargs
):
if csv is None and video_dict is None:
raise ValueError("csv and video_dict cannot be both None.")
if csv is not None:
self.csv = pd.read_csv(csv)
if video_dict is not None:
self.csv = pd.DataFrame.from_dict(video_dict)
self.centercrop = centercrop
self.size = size
self.framerate = framerate
self.hflip = hflip
def __len__(self):
return len(self.csv)
def _get_video_dim(self, video_path):
probe = ffmpeg.probe(video_path)
video_stream = next((stream for stream in probe['streams']
if stream['codec_type'] == 'video'), None)
width = int(video_stream['width'])
height = int(video_stream['height'])
return height, width
def _get_video_info(self, video_path):
probe = ffmpeg.probe(video_path)
video_stream = next((stream for stream in probe['streams']
if stream['codec_type'] == 'video'), None)
return video_stream
def _get_output_dim(self, h, w):
if isinstance(self.size, tuple) and len(self.size) == 2:
return self.size
elif h >= w:
return int(h * self.size / w), self.size
else:
return self.size, int(w * self.size / h)
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
return self._decode(output_file, video_path)
def _decode(self, output_file, video_path):
if not(os.path.isfile(output_file)) and os.path.isfile(video_path):
try:
h, w = self._get_video_dim(video_path)
except Exception:
print('ffprobe failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
try:
os.makedirs(os.path.dirname(output_file), exist_ok=True)
height, width = self._get_output_dim(h, w)
cmd = (
ffmpeg
.input(video_path)
.filter('fps', fps=self.framerate)
.filter('scale', width, height)
)
if self.hflip:
cmd = cmd.filter('hflip')
if self.centercrop:
x = int((width - self.size) / 2.0)
y = int((height - self.size) / 2.0)
cmd = cmd.crop(x, y, self.size, self.size)
video = self._run(cmd, output_file)
except Exception:
video = th.zeros(1)
else:
video = th.zeros(1)
return {'video': video, 'input': video_path, 'output': output_file}
def _run(self, cmd, output_file):
out, _ = (
cmd.output('pipe:', format='rawvideo', pix_fmt='rgb24')
.run(capture_stdout=True, quiet=True)
)
if self.centercrop and isinstance(self.size, int):
height, width = self.size, self.size
video = np.frombuffer(out, np.uint8).reshape([-1, height, width, 3])
video = th.from_numpy(video.astype('float32'))
return video.permute(0, 3, 1, 2)
class VideoVerifier(VideoLoader):
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
try:
return self._get_video_info(video_path)
except Exception:
# print('ffprobe failed at: {}'.format(video_path))
return None
class VideoCompressor(VideoLoader):
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
hflip=False,
crf=32,
**kwargs
):
super().__init__(
csv,
video_dict,
framerate,
size,
centercrop,
hflip
)
self.crf = crf
def _run(self, cmd, output_file):
out, _ = (
cmd.output(filename=output_file, crf=self.crf)
.run(quiet=True)
)
video = None
return video
class VideoDownloader(VideoCompressor):
"""download"""
def __getitem__(self, idx):
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
if not(os.path.isfile(output_file)):
os.makedirs(os.path.dirname(output_file), exist_ok=True)
cmd = "wget -O" + output_file + " " + video_path
# import subprocess
# subprocess.check_output(
# cmd,
# stderr=subprocess.STDOUT, shell=True)
os.system(cmd)
return {'video': None, 'input': video_path, 'output': output_file}
class AvKeyframeVideoCompressor(VideoLoader):
"""extract keyframes from a video and save it as jpg.
TODO: consider to merge with `CodecProcessor`.
"""
def __init__(
self,
csv=None,
video_dict=None,
framerate=1,
size=112,
centercrop=False,
max_num_frames=5,
**kwargs
):
super().__init__(csv, video_dict, framerate, size, centercrop)
self.max_num_frames = max_num_frames
def _get_video_dim(self, video_fn):
"""decord cannot probe the size of a video, we use pyav instead."""
import av
with av.open(video_fn) as container:
height = container.streams.video[0].codec_context.height
width = container.streams.video[0].codec_context.width
return height, width
def _get_output_dim(self, height, width):
"""
keep the shorter side be `self.size`, strech the other.
"""
if height >= width:
return int(height * self.size / width), self.size
else:
return self.size, int(width * self.size / height)
def __getitem__(self, idx):
import av
video_path = self.csv['video_path'].values[idx]
output_file = self.csv['feature_path'].values[idx]
if not(os.path.isdir(output_file)) and os.path.isfile(video_path):
try:
h, w = self._get_video_dim(video_path)
except Exception:
print('probe failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
try:
height, width = self._get_output_dim(h, w)
# new for av.
with av.open(video_path) as container:
container.streams.video[0].thread_type = "AUTO"
container.streams.video[0].codec_context.height = height
container.streams.video[0].codec_context.width = width
if self.framerate == 0: # keyframe.
container.streams.video[0].codec_context.skip_frame = 'NONKEY'
frames = []
for frame in container.decode(video=0):
frames.append(frame)
frames = random.sample(frames, self.max_num_frames)
os.makedirs(output_file, exist_ok=True)
for frame in frames:
frame.to_image().save(
os.path.join(
output_file,
"%04d.jpg" % frame.index))
except Exception:
print('extract failed at: {}'.format(video_path))
return {'video': th.zeros(1), 'input': video_path,
'output': output_file}
video = th.zeros(1)
return {'video': video, 'input': video_path, 'output': output_file}