HTill committed · Commit 011fa4f · verified · 1 parent: 8517158

Upload 7 files

added the modified files

Files changed (7)
  1. README.md +85 -0
  2. config.json +38 -0
  3. configuration_eat.py +58 -0
  4. eat_model.py +83 -0
  5. model.safetensors +3 -0
  6. model_core.py +294 -0
  7. modeling_eat.py +18 -0
README.md ADDED
@@ -0,0 +1,85 @@
+---
+license: mit
+tags:
+- Audio
+- SSL
+- EAT
+library_name: transformers
+---
+
+# EAT-base (Epoch 30, Pre-trained Checkpoint)
+
+This is the **pre-trained EAT-base model** at epoch 30, trained on the AS-2M dataset using the EAT framework for audio self-supervised learning.
+It offers efficient feature extraction and can also serve as a strong initialization for fine-tuning on a wide range of downstream audio understanding tasks such as classification and captioning.
+
+For more details on the EAT framework, please refer to the [GitHub repository](https://github.com/cwx-worst-one/EAT) and our paper [EAT: Self-Supervised Pre-Training with Efficient Audio Transformer](https://arxiv.org/abs/2401.03497).
+
+## 🔧 Usage
+
+You can load and use the model for feature extraction directly via Hugging Face Transformers:
+
+```python
+import torchaudio
+import torch
+import soundfile as sf
+import numpy as np
+from transformers import AutoModel
+
+model_id = "HTill/flexEAT-base_epoch30_pretrain"
+model = AutoModel.from_pretrained(model_id, trust_remote_code=True).eval().cuda()
+
+source_file = "/path/to/input.wav"
+target_file = "/path/to/output.npy"
+norm_mean = -4.268
+norm_std = 4.569
+
+# Load and resample audio
+wav, sr = sf.read(source_file)
+waveform = torch.tensor(wav).float().cuda()
+if sr != 16000:
+    waveform = torchaudio.functional.resample(waveform, sr, 16000)
+
+# Normalize and convert to mel-spectrogram
+waveform = waveform - waveform.mean()
+mel = torchaudio.compliance.kaldi.fbank(
+    waveform.unsqueeze(0),
+    htk_compat=True,
+    sample_frequency=16000,
+    use_energy=False,
+    window_type='hanning',
+    num_mel_bins=128,
+    dither=0.0,
+    frame_shift=10
+).unsqueeze(0)
+
+# Normalize
+mel = (mel - norm_mean) / (norm_std * 2)
+mel = mel.unsqueeze(0).cuda()  # shape: [1, 1, T, F]
+
+# Extract features
+with torch.no_grad():
+    feat = model.extract_features(mel)
+
+feat = feat.squeeze(0).cpu().numpy()
+np.save(target_file, feat)
+print(f"Feature shape: {feat.shape}")
+print(f"Saved to: {target_file}")
+```
+
+## 📌 Notes
+
+The model supports both **frame-level** (~50 Hz) and **utterance-level** (CLS token) representations.
+See the [feature extraction guide](https://github.com/cwx-worst-one/EAT/tree/main/feature_extract) for more instructions.
+
+## 📚 Citation
+
+If you find this model useful, please consider citing our [paper](https://arxiv.org/abs/2401.03497):
+
+```bibtex
+@article{chen2024eat,
+  title={EAT: Self-supervised pre-training with efficient audio transformer},
+  author={Chen, Wenxi and Liang, Yuzhe and Ma, Ziyang and Zheng, Zhisheng and Chen, Xie},
+  journal={arXiv preprint arXiv:2401.03497},
+  year={2024}
+}
+```
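
A note on the README above: with the pre-trained variant, `extract_features` returns one token per spectrogram patch plus a leading CLS token, so both granularities mentioned in the Notes come from a single call. A minimal sketch of separating them, assuming `output.npy` was produced by the README example and that the CLS token sits at index 0, as in the EAT repository:

```python
import numpy as np

feat = np.load("/path/to/output.npy")  # [1 + num_patches, 768]

utterance_level = feat[0]   # CLS token: one 768-dim vector for the whole clip
frame_level = feat[1:]      # patch tokens: the ~50 Hz frame-level features

print(utterance_level.shape)  # (768,)
print(frame_level.shape)      # (num_patches, 768)
```
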
config.json ADDED
@@ -0,0 +1,38 @@
+{
+  "activation_dropout": 0.0,
+  "architectures": [
+    "EATModel"
+  ],
+  "auto_map": {
+    "AutoModel": "modeling_eat.EATModel",
+    "AutoConfig": "configuration_eat.EATConfig"
+  },
+  "attn_drop_rate": 0.0,
+  "depth": 12,
+  "drop_rate": 0.0,
+  "embed_dim": 768,
+  "end_drop_path_rate": 0.0,
+  "fixed_positions": true,
+  "img_size": [
+    1024,
+    128
+  ],
+  "in_chans": 1,
+  "layer_norm_first": false,
+  "max_length": 768,
+  "mel_bins": 128,
+  "mlp_ratio": 4.0,
+  "model_type": "eat",
+  "model_variant": "pretrain",
+  "norm_affine": true,
+  "norm_eps": 1e-06,
+  "num_classes": 527,
+  "num_heads": 12,
+  "patch_size": 16,
+  "post_mlp_drop": 0.0,
+  "qkv_bias": true,
+  "start_drop_path_rate": 0.0,
+  "stride": 16,
+  "torch_dtype": "float32",
+  "transformers_version": "4.51.3"
+}
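
The `auto_map` block above is what lets `AutoConfig`/`AutoModel` resolve the custom classes when `trust_remote_code=True` is passed. A minimal sketch of inspecting the configuration; the printed values simply mirror the JSON above:

```python
from transformers import AutoConfig

config = AutoConfig.from_pretrained(
    "HTill/flexEAT-base_epoch30_pretrain", trust_remote_code=True
)

print(config.model_type)  # "eat"
print(config.embed_dim)   # 768
print(config.depth)       # 12
print(config.img_size)    # [1024, 128] -> (time frames, mel bins)
```
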
configuration_eat.py ADDED
@@ -0,0 +1,58 @@
+# configuration_eat.py
+
+from transformers import PretrainedConfig
+
+class EATConfig(PretrainedConfig):
+    model_type = "eat"
+
+    def __init__(
+        self,
+        embed_dim=768,
+        depth=12,
+        num_heads=12,
+        patch_size=16,
+        stride=16,
+        in_chans=1,
+        num_classes=527,
+        model_variant="pretrain",  # or "finetune"
+
+        mlp_ratio=4.0,
+        qkv_bias=True,
+        drop_rate=0.0,
+        attn_drop_rate=0.0,
+        activation_dropout=0.0,
+        post_mlp_drop=0.0,
+        start_drop_path_rate=0.0,
+        end_drop_path_rate=0.0,
+
+        layer_norm_first=False,
+        norm_eps=1e-6,
+        norm_affine=True,
+        fixed_positions=True,
+
+        **kwargs,
+    ):
+        super().__init__(**kwargs)
+
+        self.embed_dim = embed_dim
+        self.depth = depth
+        self.num_heads = num_heads
+        self.patch_size = patch_size
+        self.stride = stride
+        self.in_chans = in_chans
+        self.num_classes = num_classes
+        self.model_variant = model_variant
+
+        self.mlp_ratio = mlp_ratio
+        self.qkv_bias = qkv_bias
+        self.drop_rate = drop_rate
+        self.attn_drop_rate = attn_drop_rate
+        self.activation_dropout = activation_dropout
+        self.post_mlp_drop = post_mlp_drop
+        self.start_drop_path_rate = start_drop_path_rate
+        self.end_drop_path_rate = end_drop_path_rate
+
+        self.layer_norm_first = layer_norm_first
+        self.norm_eps = norm_eps
+        self.norm_affine = norm_affine
+        self.fixed_positions = fixed_positions
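
`EATConfig` can also be instantiated directly. A minimal sketch, assuming `configuration_eat.py` is importable from the working directory; fields such as `img_size`, which are not named parameters above, reach the config through `**kwargs` (`PretrainedConfig` stores unrecognized keyword arguments as attributes):

```python
from configuration_eat import EATConfig

# The defaults reproduce the EAT-base geometry from config.json.
cfg = EATConfig()
assert cfg.embed_dim == 768 and cfg.depth == 12

# img_size is not a named parameter; it is kept as an attribute via **kwargs.
cfg_ft = EATConfig(model_variant="finetune", img_size=[1024, 128])
print(cfg_ft.model_variant, cfg_ft.img_size)  # finetune [1024, 128]
```
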
eat_model.py ADDED
@@ -0,0 +1,83 @@
+import torch
+import torch.nn as nn
+from functools import partial
+import numpy as np
+from .model_core import (
+    PatchEmbed,
+    AltBlock,
+    trunc_normal_
+)
+
+class EAT(nn.Module):
+    def __init__(self, config):
+        super().__init__()
+        self.config = config
+        self.mode = config.model_variant  # "pretrain" or "finetune"
+
+        # === Embedding / Encoder ===
+        self.local_encoder = PatchEmbed(
+            img_size=config.img_size,
+            patch_size=config.patch_size,
+            in_chans=config.in_chans,
+            embed_dim=config.embed_dim,
+            stride=config.stride,
+            use_sincos_pos=config.fixed_positions
+        )
+
+        self.extra_tokens = nn.Parameter(torch.zeros(1, 1, config.embed_dim))
+        self.pos_drop = nn.Dropout(p=config.drop_rate, inplace=True)
+        trunc_normal_(self.extra_tokens, std=.02)
+
+        norm_layer = partial(nn.LayerNorm, eps=config.norm_eps, elementwise_affine=config.norm_affine)
+        dpr = np.linspace(config.start_drop_path_rate, config.end_drop_path_rate, config.depth)
+        self.blocks = nn.ModuleList([
+            AltBlock(config.embed_dim, config.num_heads, config.mlp_ratio,
+                     qkv_bias=config.qkv_bias, drop=config.drop_rate,
+                     attn_drop=config.attn_drop_rate, mlp_drop=config.activation_dropout,
+                     post_mlp_drop=config.post_mlp_drop, drop_path=dpr[i],
+                     norm_layer=norm_layer, layer_norm_first=config.layer_norm_first,
+                     ffn_targets=True)
+            for i in range(config.depth)
+        ])
+
+        self.pre_norm = norm_layer(config.embed_dim)
+
+        # === Head (for finetune) ===
+        if self.mode == "finetune":
+            self.fc_norm = nn.LayerNorm(config.embed_dim)
+            self.head = nn.Linear(config.embed_dim, config.num_classes, bias=True)
+        else:
+            self.head = nn.Identity()
+
+        self.apply(self._init_weights)
+
+    def _init_weights(self, m):
+        if isinstance(m, nn.Linear):
+            trunc_normal_(m.weight, std=.02)
+            if m.bias is not None:
+                nn.init.constant_(m.bias, 0)
+        elif isinstance(m, nn.LayerNorm):
+            nn.init.constant_(m.bias, 0)
+            nn.init.constant_(m.weight, 1.0)
+
+    def encode(self, x):
+        B = x.shape[0]
+        x = self.local_encoder(x)
+        x = torch.cat((self.extra_tokens.expand(B, -1, -1), x), dim=1)
+        x = self.pre_norm(x)
+        x = self.pos_drop(x)
+        for blk in self.blocks:
+            x, _ = blk(x)
+        return x
+
+    def forward(self, x):
+        x = self.encode(x)
+        if self.mode == "finetune":
+            x = x[:, 0]  # use cls token
+            x = self.fc_norm(x)
+            x = self.head(x)
+        return x
+
+    def extract_features(self, x):
+        x = self.encode(x)
+        return x
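
Token-count arithmetic for the encoder above: with `patch_size = stride = 16`, a `[1, 1, 1024, 128]` mel input yields a (1024/16) × (128/16) = 64 × 8 = 512 patch grid, and `encode` prepends one CLS token, giving 513 tokens of width `embed_dim = 768`. A quick shape check (a sketch that loads the published checkpoint from the Hub, since `eat_model.py` itself uses relative imports):

```python
import torch
from transformers import AutoModel

model = AutoModel.from_pretrained(
    "HTill/flexEAT-base_epoch30_pretrain", trust_remote_code=True
).eval()

mel = torch.randn(1, 1, 1024, 128)  # [batch, channel, time frames, mel bins]
with torch.no_grad():
    feat = model.extract_features(mel)

print(feat.shape)  # torch.Size([1, 513, 768]) -> 1 CLS + 64*8 patch tokens
```
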
model.safetensors ADDED
@@ -0,0 +1,3 @@
+version https://git-lfs.github.com/spec/v1
+oid sha256:8623072d09aac4f3ad1168b4fed3a24e4f68fe1da25b9fe733375efb237e5f48
+size 359905840
model_core.py ADDED
@@ -0,0 +1,294 @@
+import torch
+import torch.nn as nn
+import torch.nn.functional as F
+import numpy as np
+import collections.abc
+
+# --- Helpers (Replacements for timm functions) ---
+def to_2tuple(x):
+    if isinstance(x, collections.abc.Iterable) and not isinstance(x, str):
+        return x
+    return tuple(x for _ in range(2))
+
+
+def trunc_normal_(tensor, mean=0.0, std=1.0, a=-2.0, b=2.0):
+    """Replacement for timm.models.layers.trunc_normal_"""
+    return torch.nn.init.trunc_normal_(tensor, mean, std, a, b)
+
+
+# --- Custom Modules (No TIMM) ---
+def drop_path(
+    x, drop_prob: float = 0.0, training: bool = False, scale_by_keep: bool = True
+):
+    """Drop paths (Stochastic Depth) per sample."""
+    if drop_prob == 0.0 or not training:
+        return x
+    keep_prob = 1 - drop_prob
+    shape = (x.shape[0],) + (1,) * (x.ndim - 1)
+    random_tensor = x.new_empty(shape).bernoulli_(keep_prob)
+    if keep_prob > 0.0 and scale_by_keep:
+        random_tensor.div_(keep_prob)
+    return x * random_tensor
+
+
+class DropPath(nn.Module):
+    """Drop paths (Stochastic Depth) per sample."""
+
+    def __init__(self, drop_prob: float = 0.0, scale_by_keep: bool = True):
+        super(DropPath, self).__init__()
+        self.drop_prob = drop_prob
+        self.scale_by_keep = scale_by_keep
+
+    def forward(self, x):
+        return drop_path(x, self.drop_prob, self.training, self.scale_by_keep)
+
+    def extra_repr(self):
+        return f"drop_prob={round(self.drop_prob,3):0.3f}"
+
+
+class Mlp(nn.Module):
+    """MLP as used in Vision Transformer, MLP-Mixer and related networks"""
+
+    def __init__(
+        self,
+        in_features,
+        hidden_features=None,
+        out_features=None,
+        act_layer=nn.GELU,
+        drop=0.0,
+    ):
+        super().__init__()
+        out_features = out_features or in_features
+        hidden_features = hidden_features or in_features
+
+        self.fc1 = nn.Linear(in_features, hidden_features)
+        self.act = act_layer() if isinstance(act_layer, type) else act_layer
+        self.drop1 = nn.Dropout(drop)
+        self.fc2 = nn.Linear(hidden_features, out_features)
+        self.drop2 = nn.Dropout(drop)
+
+    def forward(self, x):
+        x = self.fc1(x)
+        x = self.act(x)
+        x = self.drop1(x)
+        x = self.fc2(x)
+        x = self.drop2(x)
+        return x
+
+
+class SinCos2DEmbed(nn.Module):
+    def __init__(self):
+        super().__init__()
+
+    def forward(self, x):
+        # x has the shape [batch_size, embed_dim, grid_length, grid_height]
+        # Note: grid_length corresponds to H (Time/Frequency), grid_height to W
+        _, embed_dim, grid_length, grid_height = x.shape
+
+        # Create grid positions
+        grid_length_a = torch.arange(grid_length, dtype=torch.float32, device=x.device)
+        grid_height_a = torch.arange(grid_height, dtype=torch.float32, device=x.device)
+        grid = torch.meshgrid(grid_length_a, grid_height_a, indexing="ij")
+
+        sub_embed_dim = embed_dim // 4
+        omega = torch.arange(sub_embed_dim, dtype=torch.float32, device=x.device)
+        omega /= sub_embed_dim
+        omega = 1.0 / 10000**omega
+
+        # embed_length (dimension 0 of grid)
+        out_length = torch.einsum("mn,d->dmn", grid[0], omega)
+        embed_length_sin = torch.sin(out_length)
+        embed_length_cos = torch.cos(out_length)
+        embed_length = torch.cat([embed_length_sin, embed_length_cos], dim=0)
+
+        # embed_height (dimension 1 of grid)
+        out_height = torch.einsum("mn,d->dmn", grid[1], omega)
+        embed_height_sin = torch.sin(out_height)
+        embed_height_cos = torch.cos(out_height)
+        embed_height = torch.cat([embed_height_sin, embed_height_cos], dim=0)
+
+        # concat length and height embeddings
+        embed = torch.cat([embed_length, embed_height], dim=0).unsqueeze(dim=0)
+
+        x = x + embed
+        return x
+
+
+class PatchEmbed(nn.Module):
+    """Flexible Image to Patch Embedding"""
+
+    def __init__(
+        self,
+        img_size=224,
+        patch_size=16,
+        in_chans=3,
+        embed_dim=768,
+        stride=16,
+        use_sincos_pos=False,
+    ):
+        super().__init__()
+        img_size = to_2tuple(img_size)
+        patch_size = to_2tuple(patch_size)
+        stride = to_2tuple(stride)
+
+        self.img_size = img_size
+        self.patch_size = patch_size
+        self.use_sincos_pos = use_sincos_pos
+
+        self.proj = nn.Conv2d(
+            in_chans, embed_dim, kernel_size=patch_size, stride=stride
+        )  # with overlapped patches
+
+        if self.use_sincos_pos:
+            self.pos_embed = SinCos2DEmbed()
+        else:
+            self.pos_embed = None
+
+    def forward(self, x):
+        x = self.proj(x)
+
+        # Apply dynamic positional embedding before flattening
+        if self.pos_embed is not None:
+            x = self.pos_embed(x)
+
+        x = x.flatten(2).transpose(1, 2)
+        return x
+
+
+class AltBlock(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads,
+        mlp_ratio=4.0,
+        qkv_bias=False,
+        qk_scale=None,
+        drop=0.0,
+        attn_drop=0.0,
+        mlp_drop=0.0,
+        post_mlp_drop=0.0,
+        drop_path=0.0,
+        act_layer=nn.GELU,
+        norm_layer=nn.LayerNorm,
+        layer_norm_first=True,
+        ffn_targets=False,
+        cosine_attention=False,
+    ):
+        super().__init__()
+
+        self.layer_norm_first = layer_norm_first
+        self.ffn_targets = ffn_targets
+
+        self.norm1 = norm_layer(dim)
+        self.attn = AltAttention(
+            dim,
+            num_heads=num_heads,
+            qkv_bias=qkv_bias,
+            qk_scale=qk_scale,
+            attn_drop=attn_drop,
+            proj_drop=drop,
+            cosine_attention=cosine_attention,
+        )
+
+        self.drop_path = DropPath(drop_path) if drop_path > 0.0 else nn.Identity()
+        self.norm2 = norm_layer(dim)
+        mlp_hidden_dim = int(dim * mlp_ratio)
+        self.mlp = Mlp(
+            in_features=dim,
+            hidden_features=mlp_hidden_dim,
+            act_layer=act_layer,
+            drop=mlp_drop,
+        )
+        self.post_mlp_dropout = nn.Dropout(post_mlp_drop, inplace=False)
+
+    def forward(self, x, padding_mask=None, alibi_bias=None):
+        if self.layer_norm_first:
+            x = x + self.drop_path(self.attn(self.norm1(x), padding_mask, alibi_bias))
+            r = x = self.mlp(self.norm2(x))
+            t = x
+            x = r + self.drop_path(self.post_mlp_dropout(x))
+            if not self.ffn_targets:
+                t = x
+        else:
+            x = x + self.drop_path(self.attn(x, padding_mask, alibi_bias))
+            r = x = self.norm1(x)
+            x = self.mlp(x)
+            t = x
+            x = self.norm2(r + self.drop_path(self.post_mlp_dropout(x)))
+            if not self.ffn_targets:
+                t = x
+
+        return x, t
+
+
+class AltAttention(nn.Module):
+    def __init__(
+        self,
+        dim,
+        num_heads=8,
+        qkv_bias=False,
+        qk_scale=None,
+        attn_drop=0.0,
+        proj_drop=0.0,
+        cosine_attention=False,
+    ):
+        super().__init__()
+        self.num_heads = num_heads
+        head_dim = dim // num_heads
+        self.scale = qk_scale or head_dim ** -0.5
+
+        self.qkv = nn.Linear(dim, dim * 3, bias=qkv_bias)
+        self.attn_drop = nn.Dropout(attn_drop)
+        self.proj = nn.Linear(dim, dim)
+        self.proj_drop = nn.Dropout(proj_drop)
+
+        self.cosine_attention = cosine_attention
+
+        if cosine_attention:
+            self.logit_scale = nn.Parameter(
+                torch.log(10 * torch.ones((num_heads, 1, 1))), requires_grad=True
+            )
+
+    def forward(self, x, padding_mask=None, alibi_bias=None):
+        B, N, C = x.shape
+        qkv = (
+            self.qkv(x)
+            .reshape(B, N, 3, self.num_heads, C // self.num_heads)
+            .permute(2, 0, 3, 1, 4)  # qkv x B x H x L x D
+        )
+        q, k, v = (
+            qkv[0],
+            qkv[1],
+            qkv[2],
+        )  # make torchscript happy (cannot use tensor as tuple)
+
+        dtype = q.dtype
+
+        if self.cosine_attention:
+            # cosine attention
+            attn = F.normalize(q, dim=-1) @ F.normalize(k, dim=-1).transpose(-2, -1)
+            logit_scale = torch.clamp(
+                self.logit_scale, max=torch.log(torch.tensor(1.0 / 0.01))
+            ).exp()
+            attn = attn * logit_scale
+        else:
+            q = q * self.scale
+            attn = q @ k.transpose(-2, -1)
+
+        if alibi_bias is not None:
+            attn = attn.type_as(alibi_bias)
+            attn[:, : alibi_bias.size(1)] += alibi_bias
+
+        if padding_mask is not None and padding_mask.any():
+            attn = attn.masked_fill(
+                padding_mask.unsqueeze(1).unsqueeze(2).to(torch.bool),
+                float("-inf"),
+            )
+
+        attn = attn.softmax(dim=-1, dtype=torch.float32).to(dtype=dtype)
+        attn = self.attn_drop(attn)
+        x = (attn @ v).transpose(1, 2)  # B x L x H x D
+        x = x.reshape(B, N, C)
+        x = self.proj(x)
+        x = self.proj_drop(x)
+        return x
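
Two of the building blocks above are easy to sanity-check in isolation. A minimal sketch, assuming `model_core.py` sits in the working directory (it has no relative imports, so it can be imported on its own):

```python
import torch
from model_core import PatchEmbed, DropPath

# PatchEmbed: [B, 1, T, F] spectrogram -> [B, N, D] tokens; with
# use_sincos_pos=True the 2-D sin/cos position code is added before flattening.
pe = PatchEmbed(img_size=[1024, 128], patch_size=16, in_chans=1,
                embed_dim=768, stride=16, use_sincos_pos=True)
print(pe(torch.randn(2, 1, 1024, 128)).shape)  # torch.Size([2, 512, 768])

# DropPath: identity in eval mode; in train mode it zeroes whole residual
# branches per sample and rescales survivors by 1/keep_prob.
dp = DropPath(drop_prob=0.1).train()
y = dp(torch.ones(4, 512, 768))
print(y[:, 0, 0])  # each entry is either 0.0 or 1/0.9 ≈ 1.1111
```
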
modeling_eat.py ADDED
@@ -0,0 +1,18 @@
+# modeling_eat.py
+
+from transformers import PreTrainedModel
+from .configuration_eat import EATConfig
+from .eat_model import EAT
+
+class EATModel(PreTrainedModel):
+    config_class = EATConfig
+
+    def __init__(self, config: EATConfig):
+        super().__init__(config)
+        self.model = EAT(config)
+
+    def forward(self, *args, **kwargs):
+        return self.model(*args, **kwargs)
+
+    def extract_features(self, x):
+        return self.model.extract_features(x)
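
`EATModel` is a thin `PreTrainedModel` wrapper around the `EAT` backbone; `config_class` here plus the `auto_map` entry in config.json are what make `AutoModel.from_pretrained(..., trust_remote_code=True)` resolve it. For completeness, a sketch of building the wrapper from a config alone (randomly initialized weights, e.g. for smoke tests). The package name is hypothetical; some package layout is required because `modeling_eat.py` uses relative imports:

```python
import torch
# Hypothetical package containing the three .py files from this commit.
from eat_pkg.configuration_eat import EATConfig
from eat_pkg.modeling_eat import EATModel

cfg = EATConfig(img_size=[1024, 128], model_variant="pretrain")
model = EATModel(cfg).eval()

with torch.no_grad():
    out = model(torch.randn(1, 1, 1024, 128))
print(out.shape)  # torch.Size([1, 513, 768]); pretrain mode returns all tokens
```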