File size: 5,283 Bytes
be13417
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
"""
 Copyright (c) 2022, salesforce.com, inc.
 All rights reserved.
 SPDX-License-Identifier: BSD-3-Clause
 For full license text, see the LICENSE file in the repo root or https://opensource.org/licenses/BSD-3-Clause
"""

import logging

import torch
import torchaudio
import torchaudio.transforms as transforms
from moviepy.editor import VideoFileClip
from omegaconf import OmegaConf
import torchaudio.compliance.kaldi as ta_kaldi

from lavis.common.registry import registry
from lavis.processors.base_processor import BaseProcessor
from lavis.models.beats.Tokenizers import TokenizersConfig, Tokenizers

# Value stored under the "MAX_INT" key of the LAVIS registry.
# NOTE(review): presumably a large-integer sentinel; it is not referenced by
# the processor defined below in this file -- confirm before removing.
MAX_INT = registry.get("MAX_INT")


@registry.register_processor("beats_audio")
class BeatsAudioProcessor(BaseProcessor):
    """Convert an audio file (or the audio track of an ``.mp4``) into fbank chunks.

    Calling the processor loads the waveform, down-mixes it to mono, resamples
    it to ``sampling_rate``, computes a 128-bin Kaldi fbank, normalizes it with
    ``fbank_mean`` / ``fbank_std`` and splits the result into chunks of
    ``frame_length`` fbank rows.

    Any failure while loading or featurizing is treated as best-effort: a
    warning is logged and an all-zero tensor of shape
    ``(n_frames, frame_length, NUM_MEL_BINS)`` is returned instead of raising.
    """

    # Number of mel bins requested from the Kaldi fbank (last output dimension).
    NUM_MEL_BINS = 128

    def __init__(self, model_name, sampling_rate, n_frames, frame_length, is_eval):
        """
        Adapted from https://github.com/NINAnor/rare_species_detections/blob/main/BEATs/BEATs.py

        Args:
            model_name: identifier stored on the instance (not used by the
                processing code in this class).
            sampling_rate: target sample rate in Hz; inputs are resampled to it.
            n_frames: number of chunks returned in training mode (and the
                chunk count of the zero placeholder tensor).
            frame_length: number of fbank rows per chunk.
            is_eval: if True, return as many chunks as the clip needs instead
                of exactly ``n_frames``.
        """
        super().__init__()

        self.model_name = model_name
        self.sampling_rate = sampling_rate
        self.n_frames = n_frames
        self.frame_length = frame_length
        # Normalization statistics applied to the fbank features.
        # NOTE(review): presumably taken from the BEATs reference code linked
        # above -- confirm before changing.
        self.fbank_mean = 15.41663
        self.fbank_std = 6.55582
        self.is_eval = is_eval

    def _empty_audio_tensor(self):
        """Return the all-zero placeholder used when audio cannot be processed."""
        return torch.zeros((self.n_frames, self.frame_length, self.NUM_MEL_BINS))

    def _load_audio(self, aupath, start_sec=None, end_sec=None):
        """Load ``aupath`` as a mono waveform sampled at ``self.sampling_rate``.

        Args:
            aupath: path to an audio file, or to an ``.mp4`` whose audio track
                is extracted.
            start_sec: optional clip start in seconds. Only honoured for
                ``.mp4`` inputs, and only when ``end_sec`` is given as well.
            end_sec: optional clip end in seconds (see ``start_sec``).

        Returns:
            torch.Tensor: 1-D waveform for ``.mp4`` and multi-channel inputs,
            ``(1, time)`` for single-channel audio files.

        Raises:
            Exception: whatever moviepy / torchaudio raise for unreadable
                input (e.g. ``AttributeError`` when a video has no audio).
        """
        if aupath.endswith(".mp4"):
            clip = VideoFileClip(aupath)
            try:
                segment = clip
                if start_sec is not None and end_sec is not None:
                    segment = clip.subclip(start_sec, end_sec)
                # moviepy decodes straight to the requested rate, so no
                # separate resampling step is needed for this branch.
                audio_np = segment.audio.to_soundarray(fps=self.sampling_rate)
            finally:
                # Release the underlying readers even if decoding failed.
                clip.close()
            if audio_np.ndim == 2:
                audio_np = audio_np.mean(axis=1)  # Convert to mono
            return torch.tensor(audio_np).float()

        waveform, sr = torchaudio.load(aupath)
        # Down-mix any multi-channel file (not just stereo) by averaging the
        # channel dimension.
        if waveform.ndim == 2 and waveform.shape[0] > 1:
            waveform = waveform.mean(dim=0)
        if sr != self.sampling_rate:
            resampler = torchaudio.transforms.Resample(sr, self.sampling_rate)
            waveform = resampler(waveform)
        return waveform

    def __call__(self, aupath, start_sec=None, end_sec=None):
        """
        Args:
            aupath: path to audio file (or ``.mp4`` video).
            start_sec: optional clip start in seconds (``.mp4`` inputs only).
            end_sec: optional clip end in seconds (``.mp4`` inputs only).
        Returns:
            torch.tensor: audio clip after transforms, shaped
            ``(chunks, frame_length, NUM_MEL_BINS)``. In training mode
            ``chunks == n_frames`` (zero-padded or truncated). In eval mode
            the clip is zero-padded up to the next multiple of
            ``frame_length`` and every chunk is returned. An all-zero
            ``(n_frames, frame_length, NUM_MEL_BINS)`` tensor is returned when
            the audio cannot be loaded or featurized.
        """
        try:
            waveform = self._load_audio(aupath, start_sec, end_sec)
        except Exception:
            # Deliberate best-effort: a bad file must not abort the caller,
            # but leave a trace so silent all-zero inputs can be diagnosed.
            logging.warning("Failed to load audio from %s", aupath, exc_info=True)
            return self._empty_audio_tensor()

        # Validate waveform
        if waveform.ndim == 0 or waveform.numel() == 0:
            return self._empty_audio_tensor()

        if waveform.ndim == 1:
            waveform = waveform.unsqueeze(0)

        # Rescale amplitudes by 2**15 before the Kaldi fbank.
        # NOTE(review): presumably maps [-1, 1] floats to the int16 range the
        # reference BEATs preprocessing uses -- confirm.
        waveform = waveform * 2**15

        # Compute fbank features
        try:
            fbank = ta_kaldi.fbank(
                waveform,
                num_mel_bins=self.NUM_MEL_BINS,
                sample_frequency=self.sampling_rate,
                frame_length=25,
                frame_shift=10,
            )
        except Exception:
            logging.warning(
                "Failed to compute fbank features for %s", aupath, exc_info=True
            )
            return self._empty_audio_tensor()

        # The Kaldi fbank yields an empty tensor when the signal is shorter
        # than one analysis window; there is nothing to chunk in that case.
        if fbank.ndim != 2 or fbank.shape[0] == 0:
            return self._empty_audio_tensor()

        fbank = (fbank - self.fbank_mean) / (2 * self.fbank_std)

        if self.is_eval:
            # Ceil-divide so a trailing partial chunk is kept (zero-padded)
            # rather than dropped.
            n_chunks = -(-fbank.shape[0] // self.frame_length)
        else:
            n_chunks = self.n_frames

        target_len = n_chunks * self.frame_length
        pad_len = target_len - fbank.shape[0]
        if pad_len > 0:
            # Pad only at the end of the time axis (second-to-last dim).
            fbank = torch.nn.functional.pad(fbank, (0, 0, 0, pad_len))

        # Consecutive runs of `frame_length` rows become one chunk each.
        return fbank[:target_len].reshape(
            n_chunks, self.frame_length, self.NUM_MEL_BINS
        )

    @classmethod
    def from_config(cls, cfg=None):
        """Build a processor from a config mapping, using defaults for missing keys.

        Args:
            cfg: config object exposing ``get(key, default)``; ``None`` means
                an empty OmegaConf config (all defaults).

        Returns:
            BeatsAudioProcessor: the configured processor.
        """
        if cfg is None:
            cfg = OmegaConf.create()

        return cls(
            model_name=cfg.get("model_name", 'iter3'),
            sampling_rate=cfg.get("sampling_rate", 16000),
            n_frames=cfg.get("n_frames", 2),
            frame_length=cfg.get("frame_length", 512),
            is_eval=cfg.get("is_eval", False),
        )