File size: 7,903 Bytes
fc605f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
e387e4c
fc605f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
381fdd8
fc605f9
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
218
219
220
221
222
223
224
225
226
227
228
229
230
231
232
233
234
235
236
237
238
239
240
241
242
243
244
245
246
247
248
249
250
251
252
# Copyright (c) Meta Platforms, Inc. and affiliates. All Rights Reserved

from typing import Optional, Tuple

import numpy as np
from core.audio_visual_encoder.config import TransformerConfig as PEAVTransformerConfig
from transformers import ModernBertConfig


class DACVAEConfig:
    """Configuration values for the DAC-VAE audio codec.

    Every constructor argument is stored on the instance under the same name.
    No validation is performed.

    Args:
        encoder_rates: Per-stage encoder rates. ``None`` (the default) selects
            a fresh ``[2, 8, 10, 12]`` list for this instance.
        decoder_rates: Per-stage decoder rates. ``None`` (the default) selects
            a fresh ``[12, 10, 8, 2]`` list for this instance.

    The remaining arguments are plain scalars stored as given.
    """

    def __init__(
        self,
        encoder_dim: int = 64,
        encoder_rates: Optional[list[int]] = None,
        latent_dim: int = 1024,
        decoder_dim: int = 1536,
        decoder_rates: Optional[list[int]] = None,
        n_codebooks: int = 16,
        codebook_size: int = 1024,
        codebook_dim: int = 128,
        quantizer_dropout: bool = False,
        sample_rate: int = 48_000,
        mean: float = 0.0,
        std: float = 1.0,
    ):
        self.encoder_dim = encoder_dim
        # Build the default lists per instance: a list literal in the signature
        # would be shared by every config created without an explicit value, so
        # mutating one config's rates would silently change all later ones.
        self.encoder_rates = [2, 8, 10, 12] if encoder_rates is None else encoder_rates
        self.latent_dim = latent_dim
        self.decoder_dim = decoder_dim
        self.decoder_rates = [12, 10, 8, 2] if decoder_rates is None else decoder_rates
        self.n_codebooks = n_codebooks
        self.codebook_size = codebook_size
        self.codebook_dim = codebook_dim
        self.quantizer_dropout = quantizer_dropout
        self.sample_rate = sample_rate
        self.mean = mean
        self.std = std

    @property
    def hop_length(self):
        """Return the product of ``encoder_rates`` as a Python ``int``."""
        return int(np.prod(self.encoder_rates))


class TextEncoderConfig:
    """Base text-encoder configuration holding a single ``dim`` value."""

    def __init__(self, dim: int = 768):
        # Stored as given; no validation is performed.
        setattr(self, "dim", dim)


class T5EncoderConfig(TextEncoderConfig):
    """Text-encoder configuration whose defaults name the ``t5-base`` model.

    ``dim`` is forwarded to the parent initializer; ``name``, ``max_length``
    and ``pad_mode`` are stored on the instance unchanged.
    """

    def __init__(
        self,
        name: str = "t5-base",
        max_length: Optional[int] = 512,
        pad_mode: str = "longest",
        dim: int = 768,
    ):
        # Parent initializer runs first, so ``dim`` is set before the rest.
        TextEncoderConfig.__init__(self, dim=dim)
        self.name, self.max_length, self.pad_mode = name, max_length, pad_mode


class VisionEncoderConfig:
    """Base vision-encoder configuration storing ``dim`` and ``batch_size``."""

    def __init__(self, dim: int = 1024, batch_size: int = 300):
        # Both values are kept exactly as passed in.
        self.dim, self.batch_size = dim, batch_size


class PerceptionEncoderConfig(VisionEncoderConfig):
    """Vision-encoder configuration whose defaults name ``PE-Core-L14-336``.

    ``dim`` and ``batch_size`` go to the parent initializer; the remaining
    arguments are stored on the instance under the same names.
    """

    def __init__(
        self,
        dim: int = 1024,
        batch_size: int = 300,
        name: str = "PE-Core-L14-336",
        normalize_feature: bool = True,
        interpolation_mode: str = "BICUBIC",
        image_size: int = 336,
    ):
        # Parent-owned fields first, then the fields specific to this class.
        VisionEncoderConfig.__init__(self, dim=dim, batch_size=batch_size)
        self.name, self.normalize_feature = name, normalize_feature
        self.interpolation_mode, self.image_size = interpolation_mode, image_size


class TransformerConfig:
    """Plain container for transformer hyper-parameters.

    Each constructor argument is stored unchanged on the instance under the
    same name; nothing is validated or derived here.
    """

    def __init__(
        self,
        dim: int = 2048,
        n_heads: int = 16,
        n_layers: int = 16,
        dropout: float = 0.1,
        norm_eps: float = 1.0e-05,
        qk_norm: bool = True,
        fc_bias: bool = False,
        ffn_exp: int = 4,
        ffn_dim_multiplier: int = 1,
        multiple_of: int = 64,
        non_linearity: str = "swiglu",
        use_rope: bool = True,
        max_positions: int = 10000,
        frequency_embedding_dim: int = 256,
        timestep_non_linearity: str = "swiglu",
        t_block_non_linearity: str = "silu",
        t_block_bias: bool = True,
        context_dim: int = 2048,
        context_non_linearity: str = "swiglu",
        context_embedder_dropout: float = 0.0,
        context_norm: bool = False,
        out_channels: int = 256,
        in_channels: Optional[int] = None,
    ):
        # Assignments follow the signature order, grouped a few per line.
        self.dim, self.n_heads, self.n_layers = dim, n_heads, n_layers
        self.dropout, self.norm_eps = dropout, norm_eps
        self.qk_norm, self.fc_bias = qk_norm, fc_bias

        # ffn_* / non-linearity fields.
        self.ffn_exp, self.ffn_dim_multiplier = ffn_exp, ffn_dim_multiplier
        self.multiple_of, self.non_linearity = multiple_of, non_linearity

        # Position- and timestep-related fields.
        self.use_rope, self.max_positions = use_rope, max_positions
        self.frequency_embedding_dim = frequency_embedding_dim
        self.timestep_non_linearity = timestep_non_linearity
        self.t_block_non_linearity, self.t_block_bias = (
            t_block_non_linearity,
            t_block_bias,
        )

        # context_* fields.
        self.context_dim = context_dim
        self.context_non_linearity = context_non_linearity
        self.context_embedder_dropout = context_embedder_dropout
        self.context_norm = context_norm

        # Channel counts; ``in_channels`` may be left as None.
        self.out_channels, self.in_channels = out_channels, in_channels


class RankerConfig:
    """Base class for ranker configurations.

    Only declares the ``kind`` annotation; no value is assigned here.
    """

    kind: str


class ImageBindRankerConfig(RankerConfig):
    """Ranker configuration tagged with kind ``"imagebind"``."""

    kind: str = "imagebind"

    def __init__(self, checkpoint: Optional[str] = None):
        # ``checkpoint`` is stored as given and may be None.
        setattr(self, "checkpoint", checkpoint)


class ClapRankerConfig(RankerConfig):
    """Ranker configuration tagged with kind ``"clap"``."""

    kind: str = "clap"

    def __init__(self, checkpoint: Optional[str] = None):
        # ``checkpoint`` is stored as given and may be None.
        setattr(self, "checkpoint", checkpoint)


class JudgeRankerConfig(RankerConfig):
    """Ranker configuration tagged with kind ``"judge"``.

    ``checkpoint_or_model_id`` defaults to ``"facebook/sam-audio-judge"`` and
    is stored on the instance unchanged.
    """

    kind: str = "judge"

    def __init__(self, checkpoint_or_model_id: str = "facebook/sam-audio-judge"):
        setattr(self, "checkpoint_or_model_id", checkpoint_or_model_id)


class SoundActivityRankerConfig(RankerConfig):
    """Ranker configuration tagged with kind ``"sound_activity"``.

    Stores ``threshold_mode``, ``sil_threshold`` and ``metric`` exactly as
    passed in.
    """

    kind: str = "sound_activity"

    def __init__(
        self,
        threshold_mode: str = "rel_to_max",
        sil_threshold: float = -40,
        metric: str = "iou",
    ):
        self.threshold_mode, self.sil_threshold, self.metric = (
            threshold_mode,
            sil_threshold,
            metric,
        )


class EnsembleRankerConfig(RankerConfig):
    """Ranker configuration tagged with kind ``"ensemble"``.

    ``rankers`` maps a string key to a ``(RankerConfig, float)`` pair and is
    stored on the instance unchanged.
    """

    kind: str = "ensemble"

    def __init__(self, rankers: dict[str, Tuple[RankerConfig, float]]):
        setattr(self, "rankers", rankers)


def parse_ranker_config(config_dict: dict):
    kind = config_dict.pop("kind")
    match kind:
        case ImageBindRankerConfig.kind:
            return ImageBindRankerConfig(**config_dict)
        case ClapRankerConfig.kind:
            return ClapRankerConfig(**config_dict)
        case JudgeRankerConfig.kind:
            return JudgeRankerConfig(**config_dict)
        case SoundActivityRankerConfig.kind:
            return SoundActivityRankerConfig(**config_dict)
        case EnsembleRankerConfig.kind:
            return EnsembleRankerConfig(
                {
                    k: (parse_ranker_config(v), w)
                    for k, (v, w) in config_dict["rankers"].items()
                }
            )


class SAMAudioConfig:
    """Top-level configuration that assembles the nested sub-configs.

    ``audio_codec``, ``text_encoder``, ``vision_encoder`` and ``transformer``
    are each expanded as keyword arguments into their config class; a falsy
    value (``None`` or an empty mapping) yields that class's defaults.
    ``visual_ranker`` and ``text_ranker`` are run through
    ``parse_ranker_config`` unless they are ``None``. The remaining arguments
    are stored unchanged.
    """

    def __init__(
        self,
        in_channels: int = 768,
        audio_codec=None,
        text_encoder=None,
        vision_encoder=None,
        transformer=None,
        num_anchors: int = 3,
        anchor_embedding_dim: int = 128,
        visual_ranker=None,
        text_ranker=None,
        span_predictor: Optional[str] = "pe-a-frame-large",
    ):
        self.in_channels = in_channels

        # Nested configs: expand the given kwargs, or use defaults when falsy.
        codec_kwargs = audio_codec if audio_codec else {}
        self.audio_codec = DACVAEConfig(**codec_kwargs)
        text_kwargs = text_encoder if text_encoder else {}
        self.text_encoder = T5EncoderConfig(**text_kwargs)
        vision_kwargs = vision_encoder if vision_encoder else {}
        self.vision_encoder = PerceptionEncoderConfig(**vision_kwargs)
        transformer_kwargs = transformer if transformer else {}
        self.transformer = TransformerConfig(**transformer_kwargs)

        self.num_anchors = num_anchors
        self.anchor_embedding_dim = anchor_embedding_dim

        # Rankers are optional; only a non-None description is parsed.
        if visual_ranker is None:
            self.visual_ranker = None
        else:
            self.visual_ranker = parse_ranker_config(visual_ranker)

        if text_ranker is None:
            self.text_ranker = None
        else:
            self.text_ranker = parse_ranker_config(text_ranker)

        self.span_predictor = span_predictor


class SAMAudioJudgeConfig:
    """Top-level configuration for the judge model's nested sub-configs.

    Each of the four sub-config arguments may be:

    * ``None`` or an empty mapping, which builds the config class with its
      defaults;
    * a mapping of keyword arguments, which is expanded into the config class;
    * an already-built instance of the config class, which is stored as-is.

    Args:
        audio_codec: Settings for ``DACVAEConfig``.
        transformer: Settings for ``PEAVTransformerConfig``.
        text_model: Settings for ``ModernBertConfig``.
        finetune_transformer: Settings for a second ``PEAVTransformerConfig``.
        nth_text_layer: Stored unchanged.
        bottleneck_dim: Stored unchanged.
    """

    def __init__(
        self,
        audio_codec: DACVAEConfig | dict | None = None,
        transformer: PEAVTransformerConfig | dict | None = None,
        text_model: ModernBertConfig | dict | None = None,
        finetune_transformer: PEAVTransformerConfig | dict | None = None,
        nth_text_layer: int = 22,
        bottleneck_dim: int = 256,
    ):
        self.audio_codec = self._coerce(DACVAEConfig, audio_codec)
        self.transformer = self._coerce(PEAVTransformerConfig, transformer)
        self.text_model = self._coerce(ModernBertConfig, text_model)
        self.finetune_transformer = self._coerce(
            PEAVTransformerConfig, finetune_transformer
        )
        self.nth_text_layer = nth_text_layer
        self.bottleneck_dim = bottleneck_dim

    @staticmethod
    def _coerce(config_cls, value):
        """Return ``value`` if it is a ``config_cls``, else build one from kwargs."""
        # The previous annotations advertised config instances, but the body
        # unconditionally did ``config_cls(**value)``, which raises TypeError
        # for a non-mapping instance. Accept both forms.
        if isinstance(value, config_cls):
            return value
        return config_cls(**(value or {}))