| """ | |
| Copyright (c) 2022, salesforce.com, inc. | |
| All rights reserved. | |
| SPDX-License-Identifier: BSD-3-Clause | |
| For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause | |
| """ | |
| import torch | |
| import torchaudio | |
| import torchaudio.transforms as transforms | |
| from moviepy.editor import VideoFileClip | |
| from omegaconf import OmegaConf | |
| import torchaudio.compliance.kaldi as ta_kaldi | |
| from lavis.common.registry import registry | |
| from lavis.processors.base_processor import BaseProcessor | |
| from lavis.models.beats.Tokenizers import TokenizersConfig, Tokenizers | |
| MAX_INT = registry.get("MAX_INT") | |
class BeatsAudioProcessor(BaseProcessor):
    def __init__(self, model_name, sampling_rate, n_frames, frame_length, is_eval):
        """
        Adapted from https://github.com/NINAnor/rare_species_detections/blob/main/BEATs/BEATs.py
        """
        super().__init__()

        self.model_name = model_name
        self.sampling_rate = sampling_rate
        self.n_frames = n_frames
        self.frame_length = frame_length
        # Normalization statistics applied to the Kaldi fbank features.
        self.fbank_mean = 15.41663
        self.fbank_std = 6.55582
        self.is_eval = is_eval

    def _load_audio(self, aupath):
        if aupath.endswith('.mp4'):
            video = VideoFileClip(aupath)
            audio_np = video.audio.to_soundarray(fps=self.sampling_rate)
            if len(audio_np.shape) == 2:
                audio_np = audio_np.mean(axis=1)  # Convert to mono
            waveform = torch.tensor(audio_np).float()
            sr = self.sampling_rate
        else:
            waveform, sr = torchaudio.load(aupath)
            if waveform.shape[0] == 2:
                waveform = torch.mean(waveform, dim=0)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)
        return waveform

    def __call__(self, aupath, start_sec=None, end_sec=None):
        """
        Args:
            aupath: path to audio file
        Returns:
            torch.tensor: audio clip after transforms.
        """
        # Helper function to return an empty tensor for invalid audio
        def empty_audio_tensor():
            return torch.zeros((self.n_frames, self.frame_length, 128))

        try:
            # Handle MP4 files
            if aupath.endswith('.mp4'):
                video = VideoFileClip(aupath)
                if start_sec is not None and end_sec is not None:
                    video = video.subclip(start_sec, end_sec)
                audio_np = video.audio.to_soundarray(fps=self.sampling_rate)
                if audio_np.ndim == 2:
                    audio_np = audio_np.mean(axis=1)  # Convert to mono
                waveform = torch.tensor(audio_np).float()
                sr = self.sampling_rate
            else:
                waveform, sr = torchaudio.load(aupath)

            # Validate waveform
            if len(waveform.shape) == 0:
                return empty_audio_tensor()

            # Convert stereo to mono
            if waveform.shape[0] == 2:
                waveform = torch.mean(waveform, dim=0)

            # Resample waveform if necessary
            if sr != self.sampling_rate:
                resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
                waveform = resampler(waveform)
        except Exception:
            return empty_audio_tensor()

        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)
        # Scale to 16-bit range, as expected by the Kaldi fbank implementation.
        waveform = waveform * 2**15

        # Compute fbank features
        try:
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=128,
                sample_frequency=self.sampling_rate,
                frame_length=25,
                frame_shift=10,
            )
            fbank = (fbank - self.fbank_mean) / (2 * self.fbank_std)
        except Exception:
            return empty_audio_tensor()

        # Handle padding and frame extraction differently for eval and training modes
        if not self.is_eval:
            # Training: pad or truncate to exactly n_frames frames of frame_length each.
            fbank_pad_len = self.frame_length * self.n_frames - fbank.shape[0]
            if fbank_pad_len > 0:
                fbank = torch.nn.ZeroPad2d((0, 0, 0, fbank_pad_len))(fbank)
            fbank = fbank[:self.frame_length * self.n_frames]
            frames = [fbank[i * self.frame_length:(i + 1) * self.frame_length].unsqueeze(0) for i in range(self.n_frames)]
        else:
            # Eval: pad up to the next multiple of frame_length so no frames are dropped.
            fbank_pad_len = (self.frame_length - fbank.shape[0] % self.frame_length) % self.frame_length
            if fbank_pad_len > 0:
                fbank = torch.nn.ZeroPad2d((0, 0, 0, fbank_pad_len))(fbank)
            curr_frames = fbank.shape[0] // self.frame_length
            frames = [fbank[i * self.frame_length:(i + 1) * self.frame_length].unsqueeze(0) for i in range(curr_frames)]
        return torch.cat(frames, dim=0)

    @classmethod
    def from_config(cls, cfg=None):
        if cfg is None:
            cfg = OmegaConf.create()

        return cls(
            model_name=cfg.get("model_name", 'iter3'),
            sampling_rate=cfg.get("sampling_rate", 16000),
            n_frames=cfg.get("n_frames", 2),
            frame_length=cfg.get("frame_length", 512),
            is_eval=cfg.get("is_eval", False),
        )
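
# Hedged usage sketch (not part of the original file): a minimal example of how
# this processor might be invoked, assuming the default values used by
# from_config above and a hypothetical 16 kHz input file "sample.wav".
# In training mode (is_eval=False) the output has shape (n_frames, frame_length, 128),
# i.e. torch.Size([2, 512, 128]) with the defaults.
if __name__ == "__main__":
    processor = BeatsAudioProcessor.from_config(
        OmegaConf.create({"n_frames": 2, "frame_length": 512, "is_eval": False})
    )
    fbank_frames = processor("sample.wav")  # hypothetical input path
    print(fbank_frames.shape)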