Spaces:
Runtime error
Runtime error
| import av | |
| import gc | |
| import torch | |
| import torchaudio | |
| import numpy as np | |
| import random | |
| import logging | |
| import io | |
| from torchvision.transforms.functional import pil_to_tensor | |
| logger = logging.getLogger(__name__) | |
def get_index(num_frames, num_segments):
    """Pick ``num_segments`` frame indices spread evenly over ``num_frames`` frames.

    Each index lands near the middle of its segment.  Returns a numpy integer
    array of length ``num_segments``.
    """
    seg_size = float(num_frames - 1) / num_segments
    offset = int(seg_size / 2)
    # Vectorized form of: offset + round(seg_size * k) for k in 0..num_segments-1.
    return offset + np.round(seg_size * np.arange(num_segments)).astype(int)
def lazy_load_s3video(s3path_video, num_frames, video_start_frame, video_end_frame, client):
    """Load ``num_frames`` evenly-sampled frames of a video stored on ceph/S3.

    Args:
        s3path_video: object-store path handed to ``client.get``.
        num_frames: number of frames to sample from the clip.
        video_start_frame / video_end_frame: inclusive frame range of the clip.
        client: storage client exposing ``get(path, enable_stream_lazyloding=...)``.
            NOTE(review): the kwarg spelling is kept as-is — it must match the
            client's API; confirm before renaming.

    Returns:
        uint8 tensor of shape (T, C, H, W), T == number of sampled frames.
    """
    assert client is not None
    video_bytes_stream = client.get(s3path_video, enable_stream_lazyloding=True)
    container = av.open(video_bytes_stream)
    stream = container.streams.video[0]
    real_fps = stream.average_rate
    time_base = stream.time_base
    start, end = video_start_frame, video_end_frame

    # Pick clip-relative frame offsets, then convert each *absolute* frame
    # number to a presentation timestamp in the stream's time_base.
    duration_frames = end - start + 1
    frame_offsets = get_index(duration_frames, num_frames)
    start_pts = int((start / real_fps) / time_base)
    # BUGFIX: targets must (a) be offset by `start` so they are absolute frame
    # numbers comparable with decoded frame.pts, and (b) apply int() after the
    # division by time_base, matching the start_pts computation above.
    pts_list = [int(((start + offset) / real_fps) / time_base) for offset in frame_offsets]

    # Seek to the nearest keyframe at/before the clip start, then decode
    # forward, keeping the first frame at/after each target pts.
    container.seek(max(start_pts, 0), stream=stream)
    frames = []
    for frame in container.decode(video=0):
        if frame.pts < start_pts:
            continue
        if not pts_list:
            break
        if frame.pts >= pts_list[0]:
            frames.append(frame)
            pts_list.pop(0)
    frames = [pil_to_tensor(f.to_rgb().to_image()).unsqueeze(0) for f in frames]
    container.close()
    del video_bytes_stream
    return torch.cat(frames, dim=0)  # (T, C, H, W)
def load_audio_av(video_path, video_start_frame, video_end_frame, sr, max_audio_length, client):  # sr should be 16000
    """Extract a normalized log-mel fbank (998 x 64) for a frame range of a stored video.

    Args:
        video_path: object-store path handed to ``client.get``.
        video_start_frame / video_end_frame: clip boundaries in *video* frames;
            converted to audio pts via the video stream's average fps.
        sr: target sample rate (expected 16000).
        max_audio_length: maximum clip length in seconds; asserted to be 10.
        client: storage client exposing ``get(path, enable_stream_lazyloding=...)``.

    Returns:
        fbank tensor zero-padded to 998 frames, or None when the container
        cannot be opened, has no audio stream, or decodes no audio frames.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path, enable_stream_lazyloding=True)
    try:
        container = av.open(video_bytes_stream)
    except Exception:
        logger.warning(f"Something wrong when av.open (video_path: {video_path})!")
        return None
    if len(container.streams.audio) == 0:
        logger.warning(f"There is no audio! (video_path: {video_path})!")
        container.close()  # BUGFIX: do not leak the container on early return
        return None
    audio_stream = container.streams.audio[0]
    real_fps = container.streams.video[0].average_rate
    time_base = audio_stream.time_base
    csr = audio_stream.sample_rate
    start_frame, end_frame = video_start_frame, video_end_frame
    start_pts = int((start_frame / real_fps) / time_base)
    end_pts = int((end_frame / real_fps) / time_base)
    frames = []
    container.seek(max(start_pts, 0), stream=audio_stream)
    try:
        for frame in container.decode(audio=0):
            if frame.pts < start_pts:
                continue
            frames.append(frame.to_ndarray())
            if frame.pts > end_pts:
                break
    except Exception:
        # Best-effort decode: keep whatever frames were collected so far.
        gc.collect()
    container.close()
    del video_bytes_stream
    if not frames:
        # BUGFIX: np.concatenate raises on an empty list; treat as "no audio".
        logger.warning(f"No audio frames decoded! (video_path: {video_path})!")
        return None
    audio = torch.from_numpy(np.concatenate(frames, 1))
    if audio.size(0) == 2:  # stereo -> mono
        audio = torch.mean(audio, dim=0, keepdim=True)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)
    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    if audio.shape[1] >= max_length:
        # Random temporal crop down to at most max_length samples.
        start = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, start: start + max_length]
    audio = audio * 2 ** 15  # scale to int16 range expected by kaldi fbank
    fbank = torchaudio.compliance.kaldi.fbank(
        audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)  # (<=998, 64)
    src_length = fbank.shape[0]
    # pad_len is non-negative because the audio was capped at 10 s above.
    pad_len = 998 - src_length
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
    return fbank
def load_full_audio_av(video_path, sr, max_audio_length, client):
    """Extract a normalized log-mel fbank (998 x 64) for the full audio of a stored video.

    Args:
        video_path: object-store path handed to ``client.get`` (downloaded fully,
            not lazily streamed).
        sr: target sample rate (expected 16000).
        max_audio_length: maximum clip length in seconds; asserted to be 10.
        client: storage client exposing ``get(path)``.

    Returns:
        fbank tensor zero-padded to 998 frames, or None when the container
        cannot be opened, has no audio stream, or decodes no audio frames.
    """
    assert client is not None
    video_bytes_stream = client.get(video_path)
    try:
        container = av.open(io.BytesIO(video_bytes_stream))
    except Exception as e:
        logger.warning(f"Something wrong {e} when av.open (video_path: {video_path})!")
        return None
    if len(container.streams.audio) == 0:
        logger.warning(f"There is no audio! (video_path: {video_path})!")
        container.close()  # BUGFIX: do not leak the container on early return
        return None
    audio_stream = container.streams.audio[0]
    csr = audio_stream.sample_rate
    frames = []
    try:
        for frame in container.decode(audio=0):
            frames.append(frame.to_ndarray())
    except Exception:
        # Best-effort decode: keep whatever frames were collected so far.
        gc.collect()
    container.close()
    del video_bytes_stream
    if not frames:
        # BUGFIX: np.concatenate raises on an empty list; treat as "no audio".
        logger.warning(f"No audio frames decoded! (video_path: {video_path})!")
        return None
    audio = torch.from_numpy(np.concatenate(frames, 1))
    if audio.size(0) == 2:  # stereo -> mono
        audio = torch.mean(audio, dim=0, keepdim=True)
    if len(audio.shape) == 1:
        audio = audio.unsqueeze(0)
    assert max_audio_length == 10, max_audio_length
    max_length = max_audio_length * sr
    if csr != sr:
        audio = torchaudio.transforms.Resample(csr, sr)(audio)
    if audio.shape[1] >= max_length:
        # Random temporal crop down to at most max_length samples.
        start = random.randint(0, audio.shape[1] - max_length)
        audio = audio[:, start: start + max_length]
    audio = audio * 2 ** 15  # scale to int16 range expected by kaldi fbank
    fbank = torchaudio.compliance.kaldi.fbank(
        audio, num_mel_bins=64, sample_frequency=16000, frame_length=25, frame_shift=10)
    fbank_mean = 15.41663
    fbank_std = 6.55582
    fbank = (fbank - fbank_mean) / (fbank_std * 2)  # (<=998, 64)
    src_length = fbank.shape[0]
    # pad_len is non-negative because the audio was capped at 10 s above.
    pad_len = 998 - src_length
    fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
    return fbank
| # frames = video_reader.get_batch(frame_indices) # (T, H, W, C), torch.uint8 | |
| # # frames = frames.permute(0, 3, 1, 2) # (T, C, H, W), torch.uint8 | |