# instructblip/lavis/processors/audio_processors.py
# Thien Huynh
# Initialization
# be13417
"""
Copyright (c) 2022, salesforce.com, inc.
All rights reserved.
SPDX-License-Identifier: BSD-3-Clause
For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""
import torch
import torchaudio
import torchaudio.transforms as transforms
from moviepy.editor import VideoFileClip
from omegaconf import OmegaConf
import torchaudio.compliance.kaldi as ta_kaldi
from lavis.common.registry import registry
from lavis.processors.base_processor import BaseProcessor
from lavis.models.beats.Tokenizers import TokenizersConfig, Tokenizers
# Value stored under the "MAX_INT" key of the shared registry. It is not
# referenced by the processor code visible in this module.
MAX_INT = registry.get("MAX_INT")
@registry.register_processor("beats_audio")
class BeatsAudioProcessor(BaseProcessor):
    """Convert an audio file (or the audio track of an ``.mp4``) into fbank frames.

    The waveform is loaded, down-mixed, resampled to ``sampling_rate``, turned
    into a 128-bin Kaldi-style filterbank, normalised, and finally cut into
    consecutive frames of ``frame_length`` filterbank rows each.

    * Training mode (``is_eval=False``): the output always has exactly
      ``n_frames`` frames; short clips are zero-padded, long clips truncated.
    * Eval mode (``is_eval=True``): the whole clip is kept; the last frame is
      zero-padded up to ``frame_length`` rows.

    Any loading or feature-extraction failure yields an all-zero tensor of
    shape ``(n_frames, frame_length, 128)`` instead of raising.
    """

    # Number of mel bins requested from the filterbank; also the last
    # dimension of the zero placeholder returned on failure.
    _NUM_MEL_BINS = 128

    def __init__(self, model_name, sampling_rate, n_frames, frame_length, is_eval):
        """
        Adapted from https://github.com/NINAnor/rare_species_detections/blob/main/BEATs/BEATs.py

        Args:
            model_name: identifier stored on the instance (not used by the
                processing code in this class).
            sampling_rate: target sample rate in Hz; audio is resampled to it.
            n_frames: number of frames produced in training mode.
            frame_length: number of filterbank rows per frame.
            is_eval: if True, keep the full clip instead of a fixed frame count.
        """
        super().__init__()
        self.model_name = model_name
        self.sampling_rate = sampling_rate
        self.n_frames = n_frames
        self.frame_length = frame_length
        # Normalisation statistics applied to the filterbank
        # (presumably the BEATs defaults -- TODO confirm against the model).
        self.fbank_mean = 15.41663
        self.fbank_std = 6.55582
        self.is_eval = is_eval

    def _empty_audio_tensor(self):
        """Return the all-zero placeholder used when audio cannot be processed."""
        return torch.zeros((self.n_frames, self.frame_length, self._NUM_MEL_BINS))

    def _load_audio(self, aupath, start_sec=None, end_sec=None):
        """Load ``aupath`` as a waveform at ``self.sampling_rate``.

        Args:
            aupath: path to an audio file, or to an ``.mp4`` video whose audio
                track should be used.
            start_sec, end_sec: optional clip boundaries in seconds. They are
                applied only to ``.mp4`` inputs and only when both are given;
                other file types are always loaded in full.

        Returns:
            torch.Tensor: 1-D for ``.mp4`` and two-channel inputs (channels are
            averaged); otherwise whatever ``torchaudio.load`` returned,
            resampled if necessary.

        Raises:
            Whatever the underlying loaders raise; this helper does not
            swallow errors.
        """
        if aupath.endswith('.mp4'):
            clip = VideoFileClip(aupath)
            try:
                segment = clip
                if start_sec is not None and end_sec is not None:
                    segment = clip.subclip(start_sec, end_sec)
                audio_np = segment.audio.to_soundarray(fps=self.sampling_rate)
            finally:
                # Release the reader resources held by the clip even when
                # decoding fails.
                clip.close()
            if audio_np.ndim == 2:
                audio_np = audio_np.mean(axis=1)  # Convert to mono
            waveform = torch.tensor(audio_np).float()
            sr = self.sampling_rate
        else:
            waveform, sr = torchaudio.load(aupath)

        # Convert stereo to mono. The ndim guard keeps a 1-D waveform that
        # happens to contain exactly two samples from being averaged away.
        if waveform.ndim == 2 and waveform.shape[0] == 2:
            waveform = torch.mean(waveform, dim=0)

        # Resample waveform if necessary
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)
        return waveform

    def _split_into_frames(self, fbank):
        """Cut a ``(rows, bins)`` filterbank into ``(k, frame_length, bins)`` frames."""
        window = self.frame_length
        n_rows = fbank.shape[0]
        if self.is_eval:
            # Pad up to the next multiple of the frame length so that no
            # trailing rows are lost and short clips still give one frame.
            pad_rows = (-n_rows) % window
            total_rows = n_rows + pad_rows
        else:
            # Fixed-size output: pad short clips, truncate long ones.
            total_rows = window * self.n_frames
            pad_rows = total_rows - n_rows
        if pad_rows > 0:
            # (left, right, top, bottom): append zero rows at the end only.
            fbank = torch.nn.functional.pad(fbank, (0, 0, 0, pad_rows))
        chunks = torch.split(fbank[:total_rows], window)
        return torch.stack(chunks, dim=0)

    def __call__(self, aupath, start_sec=None, end_sec=None):
        """
        Args:
            aupath: path to audio file
            start_sec: optional clip start in seconds (``.mp4`` inputs only)
            end_sec: optional clip end in seconds (``.mp4`` inputs only)
        Returns:
            torch.tensor: audio clip after transforms, shaped
            ``(frames, frame_length, 128)``. An all-zero tensor of shape
            ``(n_frames, frame_length, 128)`` is returned when the file cannot
            be loaded or no features can be computed.
        """
        # Best-effort loading: a broken sample must not abort the caller,
        # but KeyboardInterrupt / SystemExit are allowed to propagate.
        try:
            waveform = self._load_audio(aupath, start_sec, end_sec)
        except Exception:
            return self._empty_audio_tensor()

        # Validate waveform
        if waveform.ndim == 0:
            return self._empty_audio_tensor()
        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        # Scale up by 2**15 before the Kaldi-style filterbank (presumably to
        # match 16-bit PCM amplitudes, as in the reference BEATs code).
        waveform = waveform * 2**15

        # Compute fbank features (frame_length / frame_shift are in ms).
        try:
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=self._NUM_MEL_BINS,
                sample_frequency=self.sampling_rate,
                frame_length=25,
                frame_shift=10,
            )
            fbank = (fbank - self.fbank_mean) / (2 * self.fbank_std)
        except Exception:
            return self._empty_audio_tensor()

        # Clips too short to yield a single filterbank row cannot be framed.
        if fbank.ndim != 2 or fbank.shape[0] == 0:
            return self._empty_audio_tensor()

        return self._split_into_frames(fbank)

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from an OmegaConf-style config.

        Missing keys fall back to: ``model_name='iter3'``,
        ``sampling_rate=16000``, ``n_frames=2``, ``frame_length=512``,
        ``is_eval=False``.
        """
        if cfg is None:
            cfg = OmegaConf.create()
        return cls(
            model_name=cfg.get("model_name", 'iter3'),
            sampling_rate=cfg.get("sampling_rate", 16000),
            n_frames=cfg.get("n_frames", 2),
            frame_length=cfg.get("frame_length", 512),
            is_eval=cfg.get("is_eval", False),
        )