AndreasXi committed on
Commit
6daf432
·
verified ·
1 Parent(s): e51b2a2

Upload folder using huggingface_hub

Browse files
config.json ADDED
@@ -0,0 +1,50 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "auto_map": {
3
+ "AutoConfig": "configuration_finelap.FineLAPConfig",
4
+ "AutoModel": "modeling_finelap.FineLAPModel"
5
+ },
6
+ "architectures": [
7
+ "FineLAPModel"
8
+ ],
9
+ "audio_config": {
10
+ "_attn_implementation_autoset": true,
11
+ "activation_dropout": 0.0,
12
+ "attn_drop_rate": 0.0,
13
+ "depth": 12,
14
+ "drop_rate": 0.0,
15
+ "embed_dim": 768,
16
+ "end_drop_path_rate": 0.0,
17
+ "fixed_positions": true,
18
+ "img_size": [
19
+ 1024,
20
+ 128
21
+ ],
22
+ "in_chans": 1,
23
+ "layer_norm_first": false,
24
+ "mel_bins": 128,
25
+ "mlp_ratio": 4.0,
26
+ "model_type": "eat",
27
+ "model_variant": "pretrain",
28
+ "norm_affine": true,
29
+ "norm_eps": 1e-06,
30
+ "num_classes": 527,
31
+ "num_heads": 12,
32
+ "patch_size": 16,
33
+ "post_mlp_drop": 0.0,
34
+ "qkv_bias": true,
35
+ "start_drop_path_rate": 0.0,
36
+ "stride": 16
37
+ },
38
+ "b_global": -10.0,
39
+ "b_local": -10.0,
40
+ "embed_size": 1024,
41
+ "local_audio_proj_type": "transformer",
42
+ "model_type": "finelap",
43
+ "normalize_dense_audio_embeds": true,
44
+ "temp_global": 0.1,
45
+ "temp_local": 0.1,
46
+ "text_encoder_name": "roberta-base",
47
+ "torch_dtype": "float32",
48
+ "transformers_version": "4.51.3",
49
+ "unify_audio_proj": false
50
+ }
configuration_eat.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configuration_eat.py
2
+
3
+ from transformers import PretrainedConfig
4
+
5
class EATConfig(PretrainedConfig):
    """Hyper-parameters for the EAT (ViT-style audio transformer) backbone.

    Covers the patch-embedding geometry, transformer width/depth, the
    dropout schedule, normalization/positional options, and the classifier
    head size used in "finetune" mode.
    """

    model_type = "eat"

    def __init__(
        self,
        embed_dim=768,
        depth=12,
        num_heads=12,
        patch_size=16,
        stride=16,
        in_chans=1,
        mel_bins=128,
        max_length=768,
        num_classes=527,
        model_variant="pretrain",  # or "finetune"

        mlp_ratio=4.0,
        qkv_bias=True,
        drop_rate=0.0,
        attn_drop_rate=0.0,
        activation_dropout=0.0,
        post_mlp_drop=0.0,
        start_drop_path_rate=0.0,
        end_drop_path_rate=0.0,

        layer_norm_first=False,
        norm_eps=1e-6,
        norm_affine=True,
        fixed_positions=True,

        img_size=(1024, 128),  # (target_length, mel_bins)

        **kwargs,
    ):
        super().__init__(**kwargs)

        # Patch embedding / input geometry.
        self.patch_size = patch_size
        self.stride = stride
        self.in_chans = in_chans
        self.mel_bins = mel_bins
        self.max_length = max_length
        self.img_size = img_size

        # Transformer capacity.
        self.embed_dim = embed_dim
        self.depth = depth
        self.num_heads = num_heads
        self.mlp_ratio = mlp_ratio
        self.qkv_bias = qkv_bias

        # Regularization schedule (start/end rates feed a linear
        # stochastic-depth ramp across the layer stack).
        self.drop_rate = drop_rate
        self.attn_drop_rate = attn_drop_rate
        self.activation_dropout = activation_dropout
        self.post_mlp_drop = post_mlp_drop
        self.start_drop_path_rate = start_drop_path_rate
        self.end_drop_path_rate = end_drop_path_rate

        # Normalization / positional-encoding options.
        self.layer_norm_first = layer_norm_first
        self.norm_eps = norm_eps
        self.norm_affine = norm_affine
        self.fixed_positions = fixed_positions

        # Classification head / operating mode.
        self.num_classes = num_classes
        self.model_variant = model_variant
configuration_finelap.py ADDED
@@ -0,0 +1,39 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ from transformers import PretrainedConfig
2
+ from .configuration_eat import EATConfig
3
+
4
class FineLAPConfig(PretrainedConfig):
    """Configuration for FineLAP, an audio-text dual encoder with both
    global (clip-level) and local (dense) contrastive heads."""

    model_type = "finelap"

    def __init__(
        self,
        embed_size=1024,
        temp_global=0.1,
        b_global=-10.0,
        temp_local=0.1,
        b_local=-10.0,
        local_audio_proj_type="transformer",
        normalize_dense_audio_embeds=True,
        unify_audio_proj=False,
        text_encoder_name="roberta-base",
        audio_config=None,
        **kwargs
    ):
        # Shared embedding width and contrastive temperature/bias scalars.
        self.embed_size = embed_size
        self.temp_global = temp_global
        self.b_global = b_global
        self.temp_local = temp_local
        self.b_local = b_local

        # Projection / encoder options.
        self.local_audio_proj_type = local_audio_proj_type
        self.normalize_dense_audio_embeds = normalize_dense_audio_embeds
        self.unify_audio_proj = unify_audio_proj
        self.text_encoder_name = text_encoder_name

        # When reloaded from JSON the nested audio config arrives as a plain
        # dict; re-wrap it so self.audio_config is always an EATConfig.
        if isinstance(audio_config, EATConfig):
            self.audio_config = audio_config
        elif isinstance(audio_config, dict):
            self.audio_config = EATConfig(**audio_config)
        else:
            self.audio_config = EATConfig()

        super().__init__(**kwargs)
eat_model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.layers import trunc_normal_
4
+ from functools import partial
5
+ import numpy as np
6
+ from .eat_model_core import (
7
+ PatchEmbed_new,
8
+ get_2d_sincos_pos_embed_flexible,
9
+ FixedPositionalEncoder,
10
+ AltBlock
11
+ )
12
+
13
class EAT(nn.Module):
    """EAT (Efficient Audio Transformer) encoder.

    A ViT-style transformer over mel-spectrogram patches: a strided conv
    patch embedding, fixed 2-D sin-cos positional encodings, a stack of
    AltBlock layers, and (in "finetune" mode) a linear classification head
    over the CLS token. Attribute names are part of the checkpoint's
    state-dict layout.
    """

    def __init__(self, config):
        super().__init__()
        self.config = config
        self.mode = config.model_variant  # "pretrain" or "finetune"

        # === Embedding / Encoder ===
        self.local_encoder = PatchEmbed_new(
            img_size=config.img_size,
            patch_size=config.patch_size,
            in_chans=config.in_chans,
            embed_dim=config.embed_dim,
            stride=config.stride
        )

        # Single learnable CLS token prepended to the patch sequence.
        self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
        self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
        trunc_normal_(self.extra_tokens, std=.02)

        # Frozen sin-cos positional table; None disables positional encoding.
        self.fixed_positional_encoder = (
            FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
        )

        norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
        # Linear stochastic-depth ramp across the layer stack.
        dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
        self.blocks = nn.ModuleList([
            AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
                     qkv_bias=config.qkv_bias, drop=config.drop_rate,
                     attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
                     post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
                     norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
                     ffn_targets=True)
            for i in range(config.depth)
        ])

        self.pre_norm = norm_layer(config.embed_dim)

        # === Head (for finetune) ===
        if self.mode == "finetune":
            self.fc_norm = nn.LayerNorm(config.embed_dim)
            self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
        else:
            self.head = nn.Identity()

        self.apply(self._init_weights)

    def build_sincos_pos_embed(self):
        """Precompute the fixed 2-D sin-cos positional table.

        Returns a frozen (1, max_length * W, embed_dim) parameter, where W
        is the number of patches along the mel-bin axis.
        """
        W = self.config.mel_bins // self.config.patch_size
        max_length = self.config.max_length
        embed_dim = self.config.embed_dim
        pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
        emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
        pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
        return pos_embed

    def _init_weights(self, m):
        # Truncated-normal init for Linear weights, zeros/ones for LayerNorm.
        if isinstance(m, nn.Linear):
            trunc_normal_(m.weight, std=.02)
            if m.bias is not None:
                nn.init.constant_(m.bias, 0)
        elif isinstance(m, nn.LayerNorm):
            nn.init.constant_(m.bias, 0)
            nn.init.constant_(m.weight, 1.0)

    def encode(self, x):
        """Run the full encoder: patch embed + positions + CLS + blocks.

        x: spectrogram batch accepted by PatchEmbed_new (B, C, H, W).
        Returns (B, 1 + num_patches, embed_dim); index 0 is the CLS token.
        """
        B = x.shape[0]
        x = self.local_encoder(x)
        if self.fixed_positional_encoder is not None:
            # Positions are added to patch tokens only; the CLS token
            # (concatenated below) receives none.
            x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
        x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
        x = self.pre_norm(x)
        x = self.pos_drop(x)
        for blk in self.blocks:
            # AltBlock returns (output, target-features); targets unused here.
            x, _ = blk(x)
        return x

    def forward(self, x):
        """Finetune mode: class logits from the CLS token.
        Pretrain mode: the full encoded token sequence."""
        x = self.encode(x)
        if self.mode == "finetune":
            x = x[:, 0]  # use cls token
            x = self.fc_norm(x)
            x = self.head(x)
        return x

    def extract_features(self, x):
        """Return the encoded token sequence (alias for encode)."""
        x = self.encode(x)
        return x
eat_model_core.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from timm.models.layers import to_2tuple
6
+
7
class PatchEmbed_new(nn.Module):
    """Image-to-patch embedding with an independently configurable stride.

    Decoupling stride from patch size allows overlapping patches and
    non-square inputs such as mel spectrograms.
    """

    def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
        super().__init__()
        self.img_size = to_2tuple(img_size)
        self.patch_size = to_2tuple(patch_size)

        # Patches overlap whenever stride < patch_size.
        self.proj = nn.Conv2d(
            in_chans,
            embed_dim,
            kernel_size=self.patch_size,
            stride=to_2tuple(stride),
        )

    def forward(self, x):
        """(B, C, H, W) -> (B, num_patches, embed_dim)."""
        feats = self.proj(x)                      # (B, D, H', W')
        return feats.flatten(2).transpose(1, 2)   # (B, H'*W', D)
25
+
26
+
27
def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
    """Build 2-D sin-cos positional embeddings for an arbitrary (H, W) grid.

    grid_size: (height, width) of the patch grid.
    Returns [H*W, embed_dim], with one extra all-zero row prepended when
    cls_token is True.
    """
    h_coords = np.arange(grid_size[0], dtype=np.float32)
    w_coords = np.arange(grid_size[1], dtype=np.float32)
    # meshgrid with w first matches the patch-flattening order used upstream.
    grid = np.stack(np.meshgrid(w_coords, h_coords), axis=0)
    grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])

    pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
    if cls_token:
        pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
    return pos_embed
43
+
44
+
45
def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
    """Split embed_dim evenly across the two grid axes and concatenate the
    per-axis 1-D sin-cos embeddings into an (H*W, embed_dim) array."""
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    emb_h = get_1d_sincos_pos_embed_from_grid(half, grid[0])  # (H*W, D/2)
    emb_w = get_1d_sincos_pos_embed_from_grid(half, grid[1])  # (H*W, D/2)
    return np.concatenate([emb_h, emb_w], axis=1)             # (H*W, D)
54
+
55
+
56
def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
    """Classic transformer sin/cos encoding for a 1-D array of positions.

    embed_dim: output dimension per position (must be even)
    pos: array of positions, any shape (flattened internally), size M
    returns: (M, embed_dim) array — sin terms first, cos terms second
    """
    assert embed_dim % 2 == 0

    half = embed_dim // 2
    # Geometric frequency ladder: 10000^(-2i/D), i in [0, D/2).
    freqs = 1.0 / 10000 ** (np.arange(half, dtype=np.float32) / half)  # (D/2,)

    angles = np.outer(pos.reshape(-1), freqs)  # (M, D/2)
    return np.concatenate([np.sin(angles), np.cos(angles)], axis=1)  # (M, D)
75
+
76
+
77
class FixedPositionalEncoder(nn.Module):
    """Wraps a precomputed positional-embedding table behind the positional
    encoder interface; always returns the stored table unchanged."""

    def __init__(self, pos_embed):
        super().__init__()
        # The table is supplied fully built by the caller.
        self.positions = pos_embed

    def forward(self, x, padding_mask):
        # x and padding_mask are accepted for interface compatibility only.
        return self.positions
84
+
85
+
86
class AltBlock(nn.Module):
    """Transformer block (attention + MLP) that also returns an intermediate
    activation `t` usable as a feature-distillation target.

    With ffn_targets=True, `t` is the raw MLP output (before the final
    residual/norm); otherwise `t` equals the block output.
    """

    def __init__(
        self,
        dim,
        num_heads,
        mlp_ratio=4.0,
        qkv_bias=False,
        qk_scale=None,
        drop=0.0,
        attn_drop=0.0,
        mlp_drop=0.0,
        post_mlp_drop=0.0,
        drop_path=0.0,
        act_layer=nn.GELU,
        norm_layer=nn.LayerNorm,
        layer_norm_first=True,
        ffn_targets=False,
        cosine_attention=False,
    ):
        super().__init__()

        self.layer_norm_first = layer_norm_first
        self.ffn_targets = ffn_targets

        # Imported lazily so this module does not require timm's
        # vision_transformer until a block is actually constructed.
        from timm.models.vision_transformer import DropPath, Mlp

        self.norm1 = norm_layer(dim)
        self.attn = AltAttention(
            dim,
            num_heads=num_heads,
            qkv_bias=qkv_bias,
            qk_scale=qk_scale,
            attn_drop=attn_drop,
            proj_drop=drop,
            cosine_attention=cosine_attention,
        )

        # Stochastic depth on the residual branches; Identity when rate is 0.
        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
        self.norm2 = norm_layer(dim)
        mlp_hidden_dim = int(dim * mlp_ratio)
        self.mlp = Mlp(
            in_features=dim,
            hidden_features=mlp_hidden_dim,
            act_layer=act_layer,
            drop=mlp_drop,
        )
        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)

    def forward(self, x, padding_mask=None, alibi_bias=None):
        """Returns (block_output, target_features); see class docstring."""
        if self.layer_norm_first:
            # Pre-norm path.
            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
            # NOTE(review): `r` is bound to the MLP *output* here, so this
            # branch has no conventional pre-MLP residual. This repo's
            # config.json sets layer_norm_first=false, so the else-branch is
            # the one exercised — confirm this branch is intended before use.
            r = x = self.mlp(self.norm2(x))
            t = x
            x = r + self.drop_path(self.post_mlp_dropout(x))
            if not self.ffn_targets:
                t = x
        else:
            # Post-norm path (the one used with this repo's config).
            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
            r = x = self.norm1(x)
            x = self.mlp(x)
            t = x
            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
            if not self.ffn_targets:
                t = x

        return x, t
152
+
153
+
154
class AltAttention(nn.Module):
    """Multi-head self-attention with optional cosine attention, an additive
    alibi bias, and key-padding masking.

    Softmax is computed in float32 and cast back to the input dtype for
    numerical stability under mixed precision.
    """

    def __init__(
        self,
        dim,
        num_heads=8,
        qkv_bias=False,
        qk_scale=None,
        attn_drop=0.0,
        proj_drop=0.0,
        cosine_attention=False,
    ):
        super().__init__()
        self.num_heads = num_heads
        head_dim = dim // num_heads
        self.scale = qk_scale or head_dim ** -0.5

        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
        self.attn_drop = nn.Dropout(attn_drop)
        self.proj = nn.Linear(dim, dim)
        self.proj_drop = nn.Dropout(proj_drop)

        self.cosine_attention = cosine_attention
        if cosine_attention:
            # Learnable per-head temperature kept in log space.
            self.logit_scale = nn.Parameter(
                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
            )

    def forward(self, x, padding_mask=None, alibi_bias=None):
        B, N, C = x.shape
        head_dim = C // self.num_heads

        # One fused projection, then split into per-head q/k/v,
        # each of shape (B, H, N, head_dim).
        qkv = self.qkv(x).reshape(B, N, 3, self.num_heads, head_dim)
        q, k, v = qkv.permute(2, 0, 3, 1, 4).unbind(0)

        dtype = q.dtype

        if self.cosine_attention:
            # Cosine-similarity scores scaled by a clamped learnable temperature.
            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
            logit_scale = torch.clamp(
                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
            ).exp()
            attn = attn * logit_scale
        else:
            # Standard scaled dot-product scores.
            attn = (q * self.scale) @ k.transpose(-2, -1)

        if alibi_bias is not None:
            attn = attn.type_as(alibi_bias)
            attn[:, : alibi_bias.size(1)] += alibi_bias

        if padding_mask is not None and padding_mask.any():
            # Mask padded keys across every query and head.
            attn = attn.masked_fill(
                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
                float("-inf"),
            )

        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
        attn = self.attn_drop(attn)

        out = (attn @ v).transpose(1, 2).reshape(B, N, C)
        return self.proj_drop(self.proj(out))
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:13b9646c9f9d48513c0145bed75e654179e83f0fd8d49ed4ffc5d6b8f3353fb4
3
+ size 974773008
modeling_eat.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_eat.py
2
+
3
+ from transformers import PreTrainedModel
4
+ from .configuration_eat import EATConfig
5
+ from .eat_model import EAT
6
+
7
class EATModel(PreTrainedModel):
    """Hugging Face wrapper around the raw EAT encoder.

    Thin adapter: the backbone lives under `self.model` (the attribute name
    determines the checkpoint's state-dict key prefix) and all calls are
    forwarded to it.
    """

    config_class = EATConfig

    def __init__(self, config: EATConfig):
        super().__init__(config)
        self.model = EAT(config)

    def forward(self, *args, **kwargs):
        """Delegate to EAT.forward (logits in finetune mode, token sequence
        otherwise)."""
        return self.model(*args, **kwargs)

    def extract_features(self, x):
        """Return the encoder's token sequence for input spectrograms x."""
        return self.model.extract_features(x)
modeling_finelap.py ADDED
@@ -0,0 +1,175 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_finelap.py
2
+ import torch
3
+ import torch.nn as nn
4
+ import torch.nn.functional as F
5
+ from transformers import PreTrainedModel, RobertaModel, RobertaTokenizer
6
+
7
+ from .configuration_finelap import FineLAPConfig
8
+ from .modeling_eat import EATModel
9
+
10
class FineLAPModel(PreTrainedModel):
    """FineLAP: audio-text dual encoder with global and dense (local)
    audio embeddings.

    Pairs an EAT audio encoder with a RoBERTa text encoder; both are mapped
    into a shared `embed_size`-dim space by projection heads. Attribute
    names (audio_encoder, text_encoder, global_audio_proj, ...) are part of
    the checkpoint's state-dict layout.
    """

    config_class = FineLAPConfig

    def __init__(self, config: FineLAPConfig):
        super().__init__(config)
        self.config = config

        self.audio_encoder = EATModel(config.audio_config)
        # Width of EAT token features. NOTE(review): EATConfig defines
        # embed_dim (not hidden_size), so this getattr falls back to 768 —
        # confirm that is always the intended audio width.
        self.audio_width = getattr(config.audio_config, 'hidden_size', 768)

        self.text_encoder = RobertaModel.from_pretrained(
            config.text_encoder_name,
            add_pooling_layer=False,
        )

        self.text_width = self.text_encoder.config.hidden_size
        self.embed_size = config.embed_size

        # Learnable temperature / bias scalars for the contrastive losses;
        # each is only created when its configured value is non-zero.
        if config.temp_global != 0:
            self.temp_global = nn.Parameter(torch.ones([]) * config.temp_global)
        if config.b_global != 0:
            self.b_global = nn.Parameter(torch.ones([]) * config.b_global)
        if config.temp_local != 0:
            self.temp_local = nn.Parameter(torch.ones([]) * config.temp_local)
        if config.b_local != 0:
            self.b_local = nn.Parameter(torch.ones([]) * config.b_local)

        # Two-layer MLP projections into the shared embedding space.
        self.global_audio_proj = nn.Sequential(
            nn.Linear(self.audio_width, self.embed_size),
            nn.ReLU(),
            nn.Linear(self.embed_size, self.embed_size),
        )
        self.global_text_proj = nn.Sequential(
            nn.Linear(self.text_width, self.embed_size),
            nn.ReLU(),
            nn.Linear(self.embed_size, self.embed_size),
        )

        # 5. Local Audio Projection Layer
        self.local_audio_proj_type = config.local_audio_proj_type
        if self.local_audio_proj_type == "rnn":
            # Bidirectional GRU; the two directions concatenate to embed_size.
            self.local_audio_proj = nn.GRU(
                input_size=self.audio_width,
                hidden_size=int(self.embed_size / 2),
                num_layers=2,
                batch_first=True,
                bidirectional=True
            )
        elif self.local_audio_proj_type == "linear":
            self.local_audio_proj = nn.Sequential(
                nn.Linear(self.audio_width, self.embed_size),
                nn.ReLU(),
                nn.Linear(self.embed_size, self.embed_size)
            )
        elif self.local_audio_proj_type == "transformer":
            # Project to embed_size first, then contextualize.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.embed_size,
                nhead=8,
                dim_feedforward=self.embed_size * 4,
                dropout=0.1,
                activation='relu',
                batch_first=True
            )
            transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
            self.local_audio_proj = nn.Sequential(
                nn.Linear(self.audio_width, self.embed_size),
                transformer_encoder
            )
        elif self.local_audio_proj_type == "transformer_linearlast":
            # Contextualize at audio width, then project to embed_size.
            encoder_layer = nn.TransformerEncoderLayer(
                d_model=self.audio_width,
                nhead=8,
                dim_feedforward=self.audio_width * 4,
                dropout=0.1,
                activation='relu',
                batch_first=True
            )
            transformer_encoder = nn.TransformerEncoder(encoder_layer=encoder_layer, num_layers=2)
            self.local_audio_proj = nn.Sequential(
                transformer_encoder,
                nn.Linear(self.audio_width, self.embed_size),
            )
        else:
            raise ValueError(f"Invalid local audio proj type: {self.local_audio_proj_type}")

        self.post_init()


    def encode_audio(self, audio_mel):
        """Encode a mel spectrogram into [CLS] + temporally-downsampled
        patch tokens, shape (B, 1 + T//8, D)."""

        outputs = self.audio_encoder.extract_features(audio_mel)
        # Some encoder variants return a dict with the tokens under 'x'.
        audio_encoded_raw = outputs['x'] if isinstance(outputs, dict) else outputs

        audio_cls = audio_encoded_raw[:, 0:1, :]
        audio_patches = audio_encoded_raw[:, 1:, :]

        B, T, D = audio_patches.shape
        # Mean-pool every 8 consecutive patch tokens.
        # NOTE(review): the reshape requires T to be divisible by ds_factor —
        # confirm upstream padding guarantees this.
        ds_factor = 8
        audio_patches_downsampled = audio_patches.reshape(
            B, T // ds_factor, ds_factor, D
        ).mean(dim=2)

        # [B, 1+T//8, D]
        audio_encoded = torch.cat([audio_cls, audio_patches_downsampled], dim=1)
        return audio_encoded


    def encode_text(self, input_ids, attention_mask):
        """Return RoBERTa's last hidden state, shape (B, L, text_width)."""
        outputs = self.text_encoder(input_ids=input_ids, attention_mask=attention_mask)
        return outputs.last_hidden_state


    def get_global_text_embeds(self, input_ids, attention_mask):
        """L2-normalized projection of the <s> (first) token, (B, embed_size)."""
        text_feats = self.encode_text(input_ids, attention_mask)
        text_embeds = F.normalize(self.global_text_proj(text_feats[:, 0, :]), dim=-1)
        return text_embeds


    def get_global_audio_embeds(self, audio_mel):
        """L2-normalized clip-level audio embedding, (B, embed_size)."""
        audio_feats = self.encode_audio(audio_mel)

        if self.config.unify_audio_proj:
            # Share the local projection head and take its CLS position.
            audio_embeds = self.local_audio_proj(audio_feats)
            if self.config.local_audio_proj_type == "rnn":
                # nn.GRU returns (output, h_n); keep the output sequence.
                audio_embeds = audio_embeds[0]
            global_audio_embeds = F.normalize(audio_embeds[:, 0, :], dim=-1)
            return global_audio_embeds
        else:
            # Dedicated global head over the CLS feature.
            audio_cls_feat = audio_feats[:, 0, :]
            audio_embeds = F.normalize(self.global_audio_proj(audio_cls_feat), dim=-1)
            return audio_embeds


    def get_dense_audio_embeds(self, audio_mel):
        """Frame-level audio embeddings (patch tokens only), (B, T//8, embed_size);
        L2-normalized when configured."""
        audio_feats = self.encode_audio(audio_mel)
        audio_patches = audio_feats[:, 1:, :]

        audio_embeds = self.local_audio_proj(audio_patches)
        if self.config.local_audio_proj_type == "rnn":
            # nn.GRU returns (output, h_n); keep the output sequence.
            audio_embeds = audio_embeds[0]

        if self.config.normalize_dense_audio_embeds:
            audio_embeds = F.normalize(audio_embeds, dim=-1)
        return audio_embeds


    def forward(self, audio_mel=None, input_ids=None, attention_mask=None, return_dict=True):
        """Compute whichever embeddings the provided inputs allow; missing
        modalities yield None entries.

        NOTE(review): when audio_mel is given, the audio encoder runs twice
        (once per global/dense path) — a possible optimization, left as-is.
        """
        global_audio_embeds = None
        dense_audio_embeds = None
        global_text_embeds = None

        if audio_mel is not None:
            global_audio_embeds = self.get_global_audio_embeds(audio_mel)
            dense_audio_embeds = self.get_dense_audio_embeds(audio_mel)

        if input_ids is not None:
            global_text_embeds = self.get_global_text_embeds(input_ids, attention_mask)

        if not return_dict:
            return (global_audio_embeds, dense_audio_embeds, global_text_embeds)

        return {
            "global_audio_embeds": global_audio_embeds,
            "dense_audio_embeds": dense_audio_embeds,
            "global_text_embeds": global_text_embeds
        }