"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torchaudio
import torchaudio.transforms as transforms
from moviepy.editor import VideoFileClip
from omegaconf import OmegaConf
import torchaudio.compliance.kaldi as ta_kaldi
from lavis.common.registry import registry
from lavis.processors.base_processor import BaseProcessor
from lavis.models.beats.Tokenizers import TokenizersConfig, Tokenizers
# Module-level constant fetched from the global LAVIS registry.
# NOTE(review): presumably a large-int sentinel registered at package import
# time elsewhere in lavis.common — not visible from this file; confirm.
MAX_INT = registry.get("MAX_INT")
@registry.register_processor("beats_audio")
class BeatsAudioProcessor(BaseProcessor):
    """Audio processor producing BEATs-style filterbank frames.

    Loads audio from a file (or the audio track of an ``.mp4``), resamples
    it to ``sampling_rate``, computes 128-bin Kaldi-compatible log-mel
    filterbank features, normalizes them, and splits them into frames of
    shape ``(frame_length, 128)``.
    """

    def __init__(self, model_name, sampling_rate, n_frames, frame_length, is_eval):
        """
        Adapted from https://github.com/NINAnor/rare_species_detections/blob/main/BEATs/BEATs.py

        Args:
            model_name: identifier of the BEATs model/tokenizer variant.
            sampling_rate: target sample rate in Hz; all audio is resampled to it.
            n_frames: number of frames returned in training mode.
            frame_length: number of fbank rows per frame.
            is_eval: if True, return all available frames instead of exactly
                ``n_frames`` (see ``__call__``).
        """
        super().__init__()
        self.model_name = model_name
        self.sampling_rate = sampling_rate
        self.n_frames = n_frames
        self.frame_length = frame_length
        # Normalization constants for the fbank features — presumably the
        # statistics used by the BEATs reference implementation (TODO confirm).
        self.fbank_mean = 15.41663
        self.fbank_std = 6.55582
        self.is_eval = is_eval

    def _load_audio(self, aupath):
        """Load ``aupath`` as a mono waveform tensor at ``self.sampling_rate``.

        Supports regular audio files via torchaudio and ``.mp4`` videos via
        moviepy (audio track extracted at the target rate).
        """
        if aupath.endswith('.mp4'):
            video = VideoFileClip(aupath)
            audio_np = video.audio.to_soundarray(fps=self.sampling_rate)
            if audio_np.ndim == 2:
                audio_np = audio_np.mean(axis=1)  # convert to mono
            waveform = torch.tensor(audio_np).float()
            sr = self.sampling_rate
        else:
            waveform, sr = torchaudio.load(aupath)
            # Downmix any multi-channel audio to mono. (The original only
            # handled the exactly-2-channel case; this generalizes it while
            # behaving identically for stereo input.)
            if waveform.ndim == 2 and waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)
        return waveform

    def __call__(self, aupath, start_sec=None, end_sec=None):
        """
        Args:
            aupath: path to an audio file, or an ``.mp4`` with an audio track.
            start_sec: optional clip start (seconds); only used for ``.mp4``.
            end_sec: optional clip end (seconds); only used for ``.mp4``.

        Returns:
            torch.Tensor: stacked fbank frames of shape
            ``(num_frames, frame_length, 128)``. On any load/feature failure a
            zero tensor of shape ``(n_frames, frame_length, 128)`` is returned.
        """

        # Fallback used whenever loading or feature extraction fails.
        def empty_audio_tensor():
            return torch.zeros((self.n_frames, self.frame_length, 128))

        try:
            # Handle MP4 files via moviepy.
            if aupath.endswith('.mp4'):
                video = VideoFileClip(aupath)
                if start_sec is not None and end_sec is not None:
                    video = video.subclip(start_sec, end_sec)
                audio_np = video.audio.to_soundarray(fps=self.sampling_rate)
                if audio_np.ndim == 2:
                    audio_np = audio_np.mean(axis=1)  # convert to mono
                waveform = torch.tensor(audio_np).float()
                sr = self.sampling_rate
            else:
                waveform, sr = torchaudio.load(aupath)
            # Reject degenerate (0-dim) waveforms.
            if waveform.ndim == 0:
                return empty_audio_tensor()
            # Downmix multi-channel audio to mono (generalized from the
            # original stereo-only check; identical for 2-channel input).
            if waveform.ndim == 2 and waveform.shape[0] > 1:
                waveform = torch.mean(waveform, dim=0)
            # Resample if necessary.
            if sr != self.sampling_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
                waveform = resampler(waveform)
        except Exception:
            # Decoding can fail for many reasons (corrupt file, missing audio
            # track, ...). Was a bare `except:`; narrowed to Exception so
            # KeyboardInterrupt/SystemExit still propagate.
            return empty_audio_tensor()

        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        # Scale to 16-bit range, as expected by the Kaldi fbank frontend.
        waveform = waveform * 2 ** 15

        # Compute fbank features.
        try:
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=self.sampling_rate,
                frame_length=25,
                frame_shift=10,
            )
            fbank = (fbank - self.fbank_mean) / (2 * self.fbank_std)
        except Exception:
            return empty_audio_tensor()

        if not self.is_eval:
            # Training: zero-pad/truncate to exactly n_frames * frame_length
            # rows, then split into n_frames frames.
            target_len = self.frame_length * self.n_frames
            pad_len = target_len - fbank.shape[0]
            if pad_len > 0:
                fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
            fbank = fbank[:target_len]
            num_frames = self.n_frames
        else:
            # Eval: keep all audio; zero-pad up to the next multiple of
            # frame_length. BUGFIX: the original padded by the remainder
            # itself (fbank.shape[0] % frame_length) rather than the amount
            # needed to reach the next multiple, so the padded length was not
            # divisible by frame_length and the tail of the clip was silently
            # dropped by the integer division below.
            remainder = fbank.shape[0] % self.frame_length
            if remainder > 0:
                pad_len = self.frame_length - remainder
                fbank = torch.nn.ZeroPad2d((0, 0, 0, pad_len))(fbank)
            num_frames = fbank.shape[0] // self.frame_length
        frames = [
            fbank[i * self.frame_length:(i + 1) * self.frame_length].unsqueeze(0)
            for i in range(num_frames)
        ]
        return torch.cat(frames, dim=0)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf config, using defaults for any
        missing keys."""
        if cfg is None:
            cfg = OmegaConf.create()
        return cls(
            model_name=cfg.get("model_name", 'iter3'),
            sampling_rate=cfg.get("sampling_rate", 16000),
            n_frames=cfg.get("n_frames", 2),
            frame_length=cfg.get("frame_length", 512),
            is_eval=cfg.get("is_eval", False),
        )