ta012 commited on
Commit
82d24b9
·
verified ·
1 Parent(s): b5c0a99

Upload 7 files

Browse files
Files changed (7) hide show
  1. README.md +116 -3
  2. config.json +38 -0
  3. configuration_eat.py +66 -0
  4. eat_model.py +99 -0
  5. model.safetensors +3 -0
  6. model_core.py +224 -0
  7. modeling_eat.py +18 -0
README.md CHANGED
@@ -1,3 +1,116 @@
1
- ---
2
- license: mit
3
- ---
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ ---
2
+ license: mit
3
+ tags:
4
+ - Audio
5
+ - SSL
6
+ - SSLAM
7
+ library_name: transformers
8
+ ---
9
+
10
+ # SSLAM AudioSet-2M Finetuned (ViT Base, mAP:50.2)
11
+
12
+ This repository provides an [SSLAM](https://openreview.net/forum?id=odU59TxdiB) checkpoint formatted for use with Hugging Face Transformers. It is intended for feature extraction in audio LLMs, sound event detection, and general purpose audio representation learning. The implementation follows the [EAT](https://arxiv.org/abs/2401.03497) code path while swapping in SSLAM AudioSet-2M Finetuned weight.
13
+
14
+
15
+
16
+ ## 🔧 Usage
17
+
18
+ You can load and use the model for feature extraction directly via Hugging Face Transformers:
19
+
20
+ ```python
21
+ import torchaudio
22
+ import torch
23
+ import soundfile as sf
24
+ import numpy as np
25
+ from transformers import AutoModel
26
+
27
+ model_id = "ta012/SSLAM_AS2M_Finetuned"
28
+ model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
29
+
30
+ source_file = "/path/to/input.wav"
31
+ target_length = 1024 # Recommended: 1024 for 10s audio
32
+ norm_mean = -4.268
33
+ norm_std = 4.569
34
+
35
+ # Load and resample audio
36
+ wav, sr = sf.read(source_file)
37
+ waveform = torch.tensor(wav).float().cuda()
38
+ if sr != 16000:
39
+ waveform = torchaudio.functional.resample(waveform, sr, 16000)
40
+
41
+ # Normalize and convert to mel-spectrogram
42
+ waveform = waveform - waveform.mean()
43
+ mel = torchaudio.compliance.kaldi.fbank(
44
+ waveform.unsqueeze(0),
45
+ htk_compat=True,
46
+ sample_frequency=16000,
47
+ use_energy=False,
48
+ window_type='hanning',
49
+ num_mel_bins=128,
50
+ dither=0.0,
51
+ frame_shift=10
52
+ ).unsqueeze(0)
53
+
54
+ # Pad or truncate
55
+ n_frames = mel.shape[1]
56
+ if n_frames < target_length:
57
+ mel = torch.nn.ZeroPad2d((0, 0, 0, target_length - n_frames))(mel)
58
+ else:
59
+ mel = mel[:, :target_length, :]
60
+
61
+ # Normalize
62
+ mel = (mel - norm_mean) / (norm_std * 2)
63
+ mel = mel.unsqueeze(0).cuda() # shape: [1, 1, T, F]
64
+
65
+ # Extract features
66
+ with torch.no_grad():
67
+ feat = model.extract_features(mel)
68
+
69
+ feat = feat.squeeze(0).cpu().numpy()
70
+ print(f"Feature shape: {feat.shape}")
71
+ ```
72
+
73
+ ## 📌 Notes
74
+
75
+
76
+ See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions.
77
+
78
+
79
+ ## 🙌 Acknowledgments
80
+
81
+ This repository builds on the EAT implementation for Hugging Face models. We remap SSLAM weights to that interface.
82
+
83
+ - Paper: EAT: Self supervised pretraining with Efficient Audio Transformer
84
+ - Code: https://github.com/cwx-worst-one/EAT
85
+
86
+ We are not affiliated with the EAT authors. All credit for the original implementation belongs to them.
87
+
88
+
89
+ ## 📚 Citation
90
+
91
+
92
+ If you find our work useful, please cite it as:
93
+
94
+
95
+ ```bibtex
96
+ @inproceedings{alex2025sslam,
97
+ title={{SSLAM}: Enhancing Self-Supervised Models with Audio Mixtures for Polyphonic Soundscapes},
98
+ author={Tony Alex and Sara Atito and Armin Mustafa and Muhammad Awais and Philip J B Jackson},
99
+ booktitle={The Thirteenth International Conference on Learning Representations},
100
+ year={2025},
101
+ url={https://openreview.net/forum?id=odU59TxdiB}
102
+ }
103
+ ```
104
+
105
+
106
+
107
+ Please also cite EAT:
108
+
109
+ ```bibtex
110
+ @article{chen2024eat,
111
+ title={EAT: Self-supervised pre-training with efficient audio transformer},
112
+ author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
113
+ journal={arXiv preprint arXiv:2401.03497},
114
+ year={2024}
115
+ }
116
+ ```
config.json ADDED
@@ -0,0 +1,38 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "activation_dropout": 0.0,
3
+ "architectures": [
4
+ "EATModel"
5
+ ],
6
+ "auto_map": {
7
+ "AutoModel": "modeling_eat.EATModel",
8
+ "AutoConfig": "configuration_eat.EATConfig"
9
+ },
10
+ "model_type": "eat",
11
+ "attn_drop_rate": 0.0,
12
+ "depth": 12,
13
+ "drop_rate": 0.0,
14
+ "embed_dim": 768,
15
+ "end_drop_path_rate": 0.0,
16
+ "fixed_positions": true,
17
+ "img_size": [
18
+ 1024,
19
+ 128
20
+ ],
21
+ "in_chans": 1,
22
+ "layer_norm_first": false,
23
+ "max_length": 768,
24
+ "mel_bins": 128,
25
+ "mlp_ratio": 4.0,
26
+ "model_variant": "finetune",
27
+ "norm_affine": true,
28
+ "norm_eps": 1e-06,
29
+ "num_classes": 527,
30
+ "num_heads": 12,
31
+ "patch_size": 16,
32
+ "post_mlp_drop": 0.0,
33
+ "qkv_bias": true,
34
+ "start_drop_path_rate": 0.0,
35
+ "stride": 16,
36
+ "torch_dtype": "float32",
37
+ "transformers_version": "4.51.3"
38
+ }
configuration_eat.py ADDED
@@ -0,0 +1,66 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # configuration_eat.py
2
+
3
+ from transformers import PretrainedConfig
4
+
5
+ class EATConfig(PretrainedConfig):
6
+ model_type = "eat"
7
+
8
+ def __init__(
9
+ self,
10
+ embed_dim=768,
11
+ depth=12,
12
+ num_heads=12,
13
+ patch_size=16,
14
+ stride=16,
15
+ in_chans=1,
16
+ mel_bins=128,
17
+ max_length=768,
18
+ num_classes=527,
19
+ model_variant="pretrain", # or "finetune"
20
+
21
+ mlp_ratio=4.0,
22
+ qkv_bias=True,
23
+ drop_rate=0.0,
24
+ attn_drop_rate=0.0,
25
+ activation_dropout=0.0,
26
+ post_mlp_drop=0.0,
27
+ start_drop_path_rate=0.0,
28
+ end_drop_path_rate=0.0,
29
+
30
+ layer_norm_first=False,
31
+ norm_eps=1e-6,
32
+ norm_affine=True,
33
+ fixed_positions=True,
34
+
35
+ img_size=(1024, 128), # (target_length, mel_bins)
36
+
37
+ **kwargs,
38
+ ):
39
+ super().__init__(**kwargs)
40
+
41
+ self.embed_dim = embed_dim
42
+ self.depth = depth
43
+ self.num_heads = num_heads
44
+ self.patch_size = patch_size
45
+ self.stride = stride
46
+ self.in_chans = in_chans
47
+ self.mel_bins = mel_bins
48
+ self.max_length = max_length
49
+ self.num_classes = num_classes
50
+ self.model_variant = model_variant
51
+
52
+ self.mlp_ratio = mlp_ratio
53
+ self.qkv_bias = qkv_bias
54
+ self.drop_rate = drop_rate
55
+ self.attn_drop_rate = attn_drop_rate
56
+ self.activation_dropout = activation_dropout
57
+ self.post_mlp_drop = post_mlp_drop
58
+ self.start_drop_path_rate = start_drop_path_rate
59
+ self.end_drop_path_rate = end_drop_path_rate
60
+
61
+ self.layer_norm_first = layer_norm_first
62
+ self.norm_eps = norm_eps
63
+ self.norm_affine = norm_affine
64
+ self.fixed_positions = fixed_positions
65
+
66
+ self.img_size = img_size
eat_model.py ADDED
@@ -0,0 +1,99 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ from timm.models.layers import trunc_normal_
4
+ from functools import partial
5
+ import numpy as np
6
+ from .model_core import (
7
+ PatchEmbed_new,
8
+ get_2d_sincos_pos_embed_flexible,
9
+ FixedPositionalEncoder,
10
+ AltBlock
11
+ )
12
+
13
+ class EAT(nn.Module):
14
+ def __init__(self, config):
15
+ super().__init__()
16
+ self.config = config
17
+ self.mode = config.model_variant # "pretrain" or "finetune"
18
+
19
+ # === Embedding / Encoder ===
20
+ self.local_encoder = PatchEmbed_new(
21
+ img_size=config.img_size,
22
+ patch_size=config.patch_size,
23
+ in_chans=config.in_chans,
24
+ embed_dim=config.embed_dim,
25
+ stride=config.stride
26
+ )
27
+
28
+ self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
29
+ self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
30
+ trunc_normal_(self.extra_tokens, std=.02)
31
+
32
+ self.fixed_positional_encoder = (
33
+ FixedPositionalEncoder(self.build_sincos_pos_embed()) if config.fixed_positions else None
34
+ )
35
+
36
+ norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
37
+ dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
38
+ self.blocks = nn.ModuleList([
39
+ AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
40
+ qkv_bias=config.qkv_bias, drop=config.drop_rate,
41
+ attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
42
+ post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
43
+ norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
44
+ ffn_targets=True)
45
+ for i in range(config.depth)
46
+ ])
47
+
48
+ self.pre_norm = norm_layer(config.embed_dim)
49
+
50
+ # === Head (for finetune) ===
51
+ if self.mode == "finetune":
52
+ self.fc_norm = nn.LayerNorm(config.embed_dim)
53
+ self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
54
+ else:
55
+ self.head = nn.Identity()
56
+
57
+ self.apply(self._init_weights)
58
+
59
+ def build_sincos_pos_embed(self):
60
+ W = self.config.mel_bins // self.config.patch_size
61
+ max_length = self.config.max_length
62
+ embed_dim = self.config.embed_dim
63
+ pos_embed = nn.Parameter(torch.zeros(1, max_length * W, embed_dim), requires_grad=False)
64
+ emb = get_2d_sincos_pos_embed_flexible(embed_dim, (max_length, W), cls_token=False)
65
+ pos_embed.data.copy_(torch.from_numpy(emb).float().unsqueeze(0))
66
+ return pos_embed
67
+
68
+ def _init_weights(self, m):
69
+ if isinstance(m, nn.Linear):
70
+ trunc_normal_(m.weight, std=.02)
71
+ if m.bias is not None:
72
+ nn.init.constant_(m.bias, 0)
73
+ elif isinstance(m, nn.LayerNorm):
74
+ nn.init.constant_(m.bias, 0)
75
+ nn.init.constant_(m.weight, 1.0)
76
+
77
+ def encode(self, x):
78
+ B = x.shape[0]
79
+ x = self.local_encoder(x)
80
+ if self.fixed_positional_encoder is not None:
81
+ x = x + self.fixed_positional_encoder(x, None)[:, :x.size(1), :]
82
+ x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
83
+ x = self.pre_norm(x)
84
+ x = self.pos_drop(x)
85
+ for blk in self.blocks:
86
+ x, _ = blk(x)
87
+ return x
88
+
89
+ def forward(self, x):
90
+ x = self.encode(x)
91
+ if self.mode == "finetune":
92
+ x = x[:, 0] # use cls token
93
+ x = self.fc_norm(x)
94
+ x = self.head(x)
95
+ return x
96
+
97
+ def extract_features(self, x):
98
+ x = self.encode(x)
99
+ return x
model.safetensors ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:2049e9cf7afc48dda0319964f3c58736b99a5b8871013b5267601e312f5c11e8
3
+ size 361533396
model_core.py ADDED
@@ -0,0 +1,224 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import numpy as np
5
+ from timm.models.layers import to_2tuple
6
+
7
+ class PatchEmbed_new(nn.Module):
8
+ """ Flexible Image to Patch Embedding
9
+ """
10
+ def __init__(self, img_size=224, patch_size=16, in_chans=3, embed_dim=768, stride=16):
11
+ super().__init__()
12
+ img_size = to_2tuple(img_size)
13
+ patch_size = to_2tuple(patch_size)
14
+ stride = to_2tuple(stride)
15
+
16
+ self.img_size = img_size
17
+ self.patch_size = patch_size
18
+
19
+ self.proj = nn.Conv2d(in_chans, embed_dim, kernel_size=patch_size, stride=stride) # with overlapped patches
20
+
21
+ def forward(self, x):
22
+ x = self.proj(x)
23
+ x = x.flatten(2).transpose(1, 2)
24
+ return x
25
+
26
+
27
+ def get_2d_sincos_pos_embed_flexible(embed_dim, grid_size, cls_token=False):
28
+ """
29
+ grid_size: int of the grid height and width
30
+ return:
31
+ pos_embed: [grid_size*grid_size, embed_dim] or [1+grid_size*grid_size, embed_dim] (w/ or w/o cls_token)
32
+ """
33
+ grid_h = np.arange(grid_size[0], dtype=np.float32)
34
+ grid_w = np.arange(grid_size[1], dtype=np.float32)
35
+ grid = np.meshgrid(grid_w, grid_h) # here w goes first
36
+ grid = np.stack(grid, axis=0)
37
+
38
+ grid = grid.reshape([2, 1, grid_size[0], grid_size[1]])
39
+ pos_embed = get_2d_sincos_pos_embed_from_grid(embed_dim, grid)
40
+ if cls_token:
41
+ pos_embed = np.concatenate([np.zeros([1, embed_dim]), pos_embed], axis=0)
42
+ return pos_embed
43
+
44
+
45
+ def get_2d_sincos_pos_embed_from_grid(embed_dim, grid):
46
+ assert embed_dim % 2 == 0
47
+
48
+ # use half of dimensions to encode grid_h
49
+ emb_h = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[0]) # (H*W, D/2)
50
+ emb_w = get_1d_sincos_pos_embed_from_grid(embed_dim // 2, grid[1]) # (H*W, D/2)
51
+
52
+ emb = np.concatenate([emb_h, emb_w], axis=1) # (H*W, D)
53
+ return emb
54
+
55
+
56
+ def get_1d_sincos_pos_embed_from_grid(embed_dim, pos):
57
+ """
58
+ embed_dim: output dimension for each position
59
+ pos: a list of positions to be encoded: size (M,)
60
+ out: (M, D)
61
+ """
62
+ assert embed_dim % 2 == 0
63
+ omega = np.arange(embed_dim // 2, dtype=np.float32)
64
+ omega /= embed_dim / 2.0
65
+ omega = 1.0 / 10000 ** omega # (D/2,)
66
+
67
+ pos = pos.reshape(-1) # (M,)
68
+ out = np.einsum("m,d->md", pos, omega) # (M, D/2), outer product
69
+
70
+ emb_sin = np.sin(out) # (M, D/2)
71
+ emb_cos = np.cos(out) # (M, D/2)
72
+
73
+ emb = np.concatenate([emb_sin, emb_cos], axis=1) # (M, D)
74
+ return emb
75
+
76
+
77
+ class FixedPositionalEncoder(nn.Module):
78
+ def __init__(self, pos_embed):
79
+ super().__init__()
80
+ self.positions = pos_embed
81
+
82
+ def forward(self, x, padding_mask):
83
+ return self.positions
84
+
85
+
86
+ class AltBlock(nn.Module):
87
+ def __init__(
88
+ self,
89
+ dim,
90
+ num_heads,
91
+ mlp_ratio=4.0,
92
+ qkv_bias=False,
93
+ qk_scale=None,
94
+ drop=0.0,
95
+ attn_drop=0.0,
96
+ mlp_drop=0.0,
97
+ post_mlp_drop=0.0,
98
+ drop_path=0.0,
99
+ act_layer=nn.GELU,
100
+ norm_layer=nn.LayerNorm,
101
+ layer_norm_first=True,
102
+ ffn_targets=False,
103
+ cosine_attention=False,
104
+ ):
105
+ super().__init__()
106
+
107
+ self.layer_norm_first = layer_norm_first
108
+ self.ffn_targets = ffn_targets
109
+
110
+ from timm.models.vision_transformer import DropPath, Mlp
111
+
112
+ self.norm1 = norm_layer(dim)
113
+ self.attn = AltAttention(
114
+ dim,
115
+ num_heads=num_heads,
116
+ qkv_bias=qkv_bias,
117
+ qk_scale=qk_scale,
118
+ attn_drop=attn_drop,
119
+ proj_drop=drop,
120
+ cosine_attention=cosine_attention,
121
+ )
122
+
123
+ self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
124
+ self.norm2 = norm_layer(dim)
125
+ mlp_hidden_dim = int(dim * mlp_ratio)
126
+ self.mlp = Mlp(
127
+ in_features=dim,
128
+ hidden_features=mlp_hidden_dim,
129
+ act_layer=act_layer,
130
+ drop=mlp_drop,
131
+ )
132
+ self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
133
+
134
+ def forward(self, x, padding_mask=None, alibi_bias=None):
135
+ if self.layer_norm_first:
136
+ x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
137
+ r = x = self.mlp(self.norm2(x))
138
+ t = x
139
+ x = r + self.drop_path(self.post_mlp_dropout(x))
140
+ if not self.ffn_targets:
141
+ t = x
142
+ else:
143
+ x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
144
+ r = x = self.norm1(x)
145
+ x = self.mlp(x)
146
+ t = x
147
+ x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
148
+ if not self.ffn_targets:
149
+ t = x
150
+
151
+ return x, t
152
+
153
+
154
+ class AltAttention(nn.Module):
155
+ def __init__(
156
+ self,
157
+ dim,
158
+ num_heads=8,
159
+ qkv_bias=False,
160
+ qk_scale=None,
161
+ attn_drop=0.0,
162
+ proj_drop=0.0,
163
+ cosine_attention=False,
164
+ ):
165
+ super().__init__()
166
+ self.num_heads = num_heads
167
+ head_dim = dim // num_heads
168
+ self.scale = qk_scale or head_dim ** -0.5
169
+
170
+ self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
171
+ self.attn_drop = nn.Dropout(attn_drop)
172
+ self.proj = nn.Linear(dim, dim)
173
+ self.proj_drop = nn.Dropout(proj_drop)
174
+
175
+ self.cosine_attention = cosine_attention
176
+
177
+ if cosine_attention:
178
+ self.logit_scale = nn.Parameter(
179
+ torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
180
+ )
181
+
182
+ def forward(self, x, padding_mask=None, alibi_bias=None):
183
+ B, N, C = x.shape
184
+ qkv = (
185
+ self.qkv(x)
186
+ .reshape(B, N, 3, self.num_heads, C // self.num_heads)
187
+ .permute(2, 0, 3, 1, 4) # qkv x B x H x L x D
188
+ )
189
+ q, k, v = (
190
+ qkv[0],
191
+ qkv[1],
192
+ qkv[2],
193
+ ) # make torchscript happy (cannot use tensor as tuple)
194
+
195
+ dtype = q.dtype
196
+
197
+ if self.cosine_attention:
198
+ # cosine attention
199
+ attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
200
+ logit_scale = torch.clamp(
201
+ self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
202
+ ).exp()
203
+ attn = attn * logit_scale
204
+ else:
205
+ q = q * self.scale
206
+ attn = q @ k.transpose(-2, -1)
207
+
208
+ if alibi_bias is not None:
209
+ attn = attn.type_as(alibi_bias)
210
+ attn[:, : alibi_bias.size(1)] += alibi_bias
211
+
212
+ if padding_mask is not None and padding_mask.any():
213
+ attn = attn.masked_fill(
214
+ padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
215
+ float("-inf"),
216
+ )
217
+
218
+ attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
219
+ attn = self.attn_drop(attn)
220
+ x = (attn @ v).transpose(1, 2) #
221
+ x = x.reshape(B, N, C)
222
+ x = self.proj(x)
223
+ x = self.proj_drop(x)
224
+ return x
modeling_eat.py ADDED
@@ -0,0 +1,18 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ # modeling_eat.py
2
+
3
+ from transformers import PreTrainedModel
4
+ from .configuration_eat import EATConfig
5
+ from .eat_model import EAT
6
+
7
+ class EATModel(PreTrainedModel):
8
+ config_class = EATConfig
9
+
10
+ def __init__(self, config: EATConfig):
11
+ super().__init__(config)
12
+ self.model = EAT(config)
13
+
14
+ def forward(self, *args, **kwargs):
15
+ return self.model(*args, **kwargs)
16
+
17
+ def extract_features(self, x):
18
+ return self.model.extract_features(x)