Commit 73b6cc3 by wh1tet3a · 0 parent(s)

add model

Files changed (5)
  1. .gitattributes +35 -0
  2. README.md +191 -0
  3. config.json +1 -0
  4. model.py +955 -0
  5. model.safetensors +3 -0
.gitattributes ADDED
@@ -0,0 +1,35 @@
+ *.7z filter=lfs diff=lfs merge=lfs -text
+ *.arrow filter=lfs diff=lfs merge=lfs -text
+ *.bin filter=lfs diff=lfs merge=lfs -text
+ *.bz2 filter=lfs diff=lfs merge=lfs -text
+ *.ckpt filter=lfs diff=lfs merge=lfs -text
+ *.ftz filter=lfs diff=lfs merge=lfs -text
+ *.gz filter=lfs diff=lfs merge=lfs -text
+ *.h5 filter=lfs diff=lfs merge=lfs -text
+ *.joblib filter=lfs diff=lfs merge=lfs -text
+ *.lfs.* filter=lfs diff=lfs merge=lfs -text
+ *.mlmodel filter=lfs diff=lfs merge=lfs -text
+ *.model filter=lfs diff=lfs merge=lfs -text
+ *.msgpack filter=lfs diff=lfs merge=lfs -text
+ *.npy filter=lfs diff=lfs merge=lfs -text
+ *.npz filter=lfs diff=lfs merge=lfs -text
+ *.onnx filter=lfs diff=lfs merge=lfs -text
+ *.ot filter=lfs diff=lfs merge=lfs -text
+ *.parquet filter=lfs diff=lfs merge=lfs -text
+ *.pb filter=lfs diff=lfs merge=lfs -text
+ *.pickle filter=lfs diff=lfs merge=lfs -text
+ *.pkl filter=lfs diff=lfs merge=lfs -text
+ *.pt filter=lfs diff=lfs merge=lfs -text
+ *.pth filter=lfs diff=lfs merge=lfs -text
+ *.rar filter=lfs diff=lfs merge=lfs -text
+ *.safetensors filter=lfs diff=lfs merge=lfs -text
+ saved_model/**/* filter=lfs diff=lfs merge=lfs -text
+ *.tar.* filter=lfs diff=lfs merge=lfs -text
+ *.tar filter=lfs diff=lfs merge=lfs -text
+ *.tflite filter=lfs diff=lfs merge=lfs -text
+ *.tgz filter=lfs diff=lfs merge=lfs -text
+ *.wasm filter=lfs diff=lfs merge=lfs -text
+ *.xz filter=lfs diff=lfs merge=lfs -text
+ *.zip filter=lfs diff=lfs merge=lfs -text
+ *.zst filter=lfs diff=lfs merge=lfs -text
+ *tfevents* filter=lfs diff=lfs merge=lfs -text
README.md ADDED
@@ -0,0 +1,191 @@
+ ---
+ library_name: pytorch
+ tags:
+ - audio
+ - spoofing-detection
+ - anti-spoofing
+ - wav2vec2
+ - aasist
+ - aasist3
+ license: apache-2.0
+ pipeline_tag: audio-classification
+ model-index:
+ - name: spectra_aasist3
+   results:
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof19_LA
+       type: ASVspoof19_LA
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 0.723
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof21_LA
+       type: ASVspoof21_LA
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 4.506
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof21_DF
+       type: ASVspoof21_DF
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 1.998
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ASVspoof5
+       type: ASVspoof5
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 13.82
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: ADD2022
+       type: ADD2022
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 15.187
+   - task:
+       type: Speech Antispoofing
+     dataset:
+       name: In-the-Wild
+       type: In-the-Wild
+     metrics:
+     - name: Equal Error Rate
+       type: Equal Error Rate
+       value: 0.961
+ ---
+
+ ## Model Card: Spectra-0 (anti-spoofing / bonafide vs spoof)
+
+ `Spectra-AASIST3` is a model for **speech spoofing detection** (binary classification: `bonafide` vs `spoof`) from **raw audio waveforms**. Architecture: SSL encoder (`Wav2Vec2`) → MLP projection → `AASIST3` 2-class classifier.
+
+ - **Input**: `float32` waveform of shape `(batch, num_samples)`, typically sampled at 16 kHz.
+ - **Output**: logits of shape `(batch, 2)`, where **index 0 = spoof** and **index 1 = bonafide**.
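
The two output values are unnormalized logits. When a probability-like value is more convenient, a softmax over the class axis can be applied. The sketch below does this in plain Python for one hypothetical logit pair; the helper name `softmax2` and the example values are illustrative assumptions, not part of `model.py` and not real model output.

```python
import math


def softmax2(logit_spoof: float, logit_bonafide: float) -> tuple[float, float]:
    """Convert a (spoof, bonafide) logit pair into probabilities that sum to 1."""
    # Subtract the larger logit before exponentiating for numerical stability.
    m = max(logit_spoof, logit_bonafide)
    e_spoof = math.exp(logit_spoof - m)
    e_bonafide = math.exp(logit_bonafide - m)
    total = e_spoof + e_bonafide
    return e_spoof / total, e_bonafide / total


# Hypothetical logits in the documented order: index 0 = spoof, index 1 = bonafide.
p_spoof, p_bonafide = softmax2(2.0, -1.0)
print(f"p_spoof={p_spoof:.3f}, p_bonafide={p_bonafide:.3f}")
```

With tensors, `torch.softmax(logits, dim=-1)` computes the same thing row by row. The default `classify()` threshold is defined on the raw bonafide logit, so it does not carry over to these probabilities.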
+
+ On first run, the model automatically downloads the SSL encoder `facebook/wav2vec2-xls-r-300m` via `transformers`.
+
+ ## Evaluation Results
+
+ Equal Error Rate (EER, %) on each dataset; lower is better. The best result in each column is in bold.
+
+ | Model | ASVspoof19 LA | ASVspoof21 LA | ASVspoof21 DF | ASVspoof5 | ADD2022 | In-the-Wild |
+ |-----------|--------|--------|--------|--------|--------|--------|
+ | [Res2TCNGuard](https://github.com/mtuciru/Res2TCNGuard) | 7.487 | 19.130 | 19.883 | 37.620 | 49.538 | 49.246 |
+ | [AASIST3](https://huggingface.co/MTUCI/AASIST3) | 27.585 | 37.407 | 33.099 | 41.001 | 47.192 | 39.626 |
+ | [XSLS](https://github.com/QiShanZhang/SLSforASVspoof-2021-DF) | 0.231 | 7.714 | 4.220 | 17.688 | 33.951 | 7.453 |
+ | [TCM-ADD](https://github.com/ductuantruong/tcm_add) | **0.152** | 6.655 | 3.444 | 19.505 | 35.252 | 7.767 |
+ | [DF Arena 1B](https://huggingface.co/Speech-Arena-2025/DF_Arena_1B_V_1) | 43.793 | 40.137 | 42.994 | 35.333 | 42.139 | 17.598 |
+ | **Spectra-AASIST3** | 0.723 | **4.506** | **1.998** | **13.82** | **15.187** | **0.961** |
+
+ ## Quickstart
+
+ ### Clone from Hugging Face
+
+ This repository is hosted on the Hugging Face Hub: `https://huggingface.co/MTUCI/spectra_aasist3`.
+
+ ```bash
+ git lfs install
+ git clone https://huggingface.co/MTUCI/spectra_aasist3
+ cd spectra_aasist3
+ ```
+
+ ### Install dependencies
+
+ ```bash
+ pip install -U torch torchaudio transformers huggingface_hub safetensors soundfile
+ ```
+
+ ### Single-file inference (example preprocessing)
+
+ ```python
+ import random
+ import torch
+ import torchaudio
+ import soundfile as sf
+
+ from model import spectra_aasist3
+
+
+ def pad_random(x: torch.Tensor, max_len: int = 64600) -> torch.Tensor:
+     # x: (num_samples,) or (1, num_samples)
+     if x.ndim > 1:
+         x = x.squeeze()
+     x_len = x.shape[0]
+     if x_len >= max_len:
+         # Long clip: take a random crop of max_len samples.
+         start = random.randint(0, x_len - max_len)
+         return x[start:start + max_len]
+     # Short clip: tile it until it covers max_len, then truncate.
+     num_repeats = int(max_len / x_len) + 1
+     return x.repeat(num_repeats)[:max_len]
+
+
+ def load_audio_mono(path: str) -> torch.Tensor:
+     audio, sr = sf.read(path, dtype="float32")
+     audio = torch.from_numpy(audio)
+     if audio.ndim > 1:
+         # (num_samples, channels) -> mono
+         audio = audio.mean(dim=1)
+     if sr != 16000:
+         audio = torchaudio.functional.resample(audio, sr, 16000)
+     return audio
+
+
+ device = "cuda" if torch.cuda.is_available() else "cpu"
+ model = spectra_aasist3.from_pretrained(pretrained_model_name_or_path=".").eval().to(device)
+
+ audio = load_audio_mono("path/to/audio.wav")
+ audio = torchaudio.functional.preemphasis(audio.unsqueeze(0))  # (1, T)
+ audio = pad_random(audio.squeeze(0), 64600).unsqueeze(0)  # (1, 64600)
+
+ with torch.inference_mode():
+     logits = model(audio.to(device))  # (1, 2)
+     score_spoof = logits[0, 0].item()
+     score_bonafide = logits[0, 1].item()
+
+ print({"score_bonafide": score_bonafide, "score_spoof": score_spoof})
+ ```
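
To make the two preprocessing steps above concrete, here is a minimal plain-Python sketch of what they do to a signal. It assumes `torchaudio.functional.preemphasis` is called with its default coefficient of 0.97 and leaves the first sample unchanged; the function names `preemphasis` and `tile_to_length` and the toy signals are illustrative only.

```python
def preemphasis(x: list[float], coeff: float = 0.97) -> list[float]:
    """y[0] = x[0]; y[i] = x[i] - coeff * x[i - 1] for i >= 1."""
    if not x:
        return []
    return [x[0]] + [x[i] - coeff * x[i - 1] for i in range(1, len(x))]


def tile_to_length(x: list[float], max_len: int) -> list[float]:
    """Repeat a short, non-empty signal until it covers max_len samples, then truncate."""
    num_repeats = max_len // len(x) + 1
    return (x * num_repeats)[:max_len]


# A flat signal is almost cancelled by pre-emphasis; the jump at the end survives.
signal = [1.0, 1.0, 1.0, 2.0]
emphasized = preemphasis(signal)
print(emphasized)

# A 3-sample clip tiled to 8 samples, mirroring the short-clip branch of pad_random.
print(tile_to_length([1.0, 2.0, 3.0], 8))
```

At 16 kHz, 64600 samples correspond to roughly 4.04 s of audio. Because `pad_random` crops long clips at a random offset, scores for long files can differ between runs unless the `random` seed is fixed.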
+
+ ## Threshold-based classification (and how to tune it)
+
+ In `model.py`, the `SpectraAASIST3` class provides `classify()` with a **default threshold** chosen as an “optimal” value for the original setting:
+
+ - **Default threshold**: `-1.0625009`, applied to the bonafide logit (`logit_bonafide = logits[:, 1]`).
+ - **Note**: this threshold **may not be optimal** on a different dataset or domain. Tune it on your own data, using either the **EER** (Equal Error Rate) operating point or a target FAR/FRR.
+
+ Example:
+
+ ```python
+ with torch.inference_mode():
+     pred = model.classify(audio.to(device), threshold=-1.0625009)  # 1=bonafide, 0=spoof
+ ```
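
For the target-FAR option mentioned above, one possible approach is to pick the threshold directly from the scores of known spoof utterances. The sketch below is only an illustration under stated assumptions: a trial counts as accepted when its bonafide logit is strictly greater than the threshold (the exact comparison used by `classify()` is not shown here), the helper names are hypothetical, and the scores are made up.

```python
def far_at(spoof_scores: list[float], threshold: float) -> float:
    """False-acceptance rate: share of spoof trials with score > threshold."""
    return sum(s > threshold for s in spoof_scores) / len(spoof_scores)


def threshold_for_target_far(spoof_scores: list[float], target_far: float) -> float:
    """Pick an observed spoof score t such that FAR(t) <= target_far."""
    ordered = sorted(spoof_scores)
    n = len(ordered)
    # Number of spoof trials that may score above the threshold.
    allowed = int(target_far * n)
    # Everything after this index (at most `allowed` items) lies above the threshold.
    idx = max(n - allowed - 1, 0)
    return ordered[idx]


# Hypothetical bonafide-logit scores measured on spoof utterances.
spoof_scores = [-5.0, -3.0, -2.5, -1.0, 0.5, -4.0, -3.5, -2.0, -6.0, -1.5]
thr = threshold_for_target_far(spoof_scores, target_far=0.1)
print(thr, far_at(spoof_scores, thr))
```

A threshold chosen this way fixes the FAR only on the data it was tuned on; the resulting FRR should then be checked on bonafide utterances from the same domain.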
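+
+ ### Tuning the threshold via EER (typical workflow)
+
+ 1) Run the model on a labeled set and collect the scores (`logits[:, 1]`) for bonafide and spoof utterances.
+
+ 2) Compute the EER and take the threshold at which the false-acceptance and false-rejection rates are (approximately) equal.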
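
The two steps above can be sketched in plain Python as follows. This is a minimal illustration rather than the authors' evaluation code: the function name `compute_eer`, the acceptance rule (`score >= threshold` means bonafide), and the toy scores are all assumptions.

```python
def compute_eer(bonafide_scores: list[float], spoof_scores: list[float]) -> tuple[float, float]:
    """Return (eer, threshold), trying every observed score as a candidate threshold.

    FAR = share of spoof trials accepted (score >= threshold).
    FRR = share of bonafide trials rejected (score < threshold).
    """
    candidates = sorted(set(bonafide_scores) | set(spoof_scores))
    best_gap, best_eer, best_thr = float("inf"), 1.0, candidates[0]
    for thr in candidates:
        far = sum(s >= thr for s in spoof_scores) / len(spoof_scores)
        frr = sum(s < thr for s in bonafide_scores) / len(bonafide_scores)
        gap = abs(far - frr)
        if gap < best_gap:
            # Keep the point where FAR and FRR are closest; report their mean as the EER.
            best_gap, best_eer, best_thr = gap, (far + frr) / 2, thr
    return best_eer, best_thr


# Hypothetical bonafide-logit scores collected in step 1.
bonafide_scores = [2.0, 1.5, 0.5, -0.5]
spoof_scores = [-3.0, -2.0, -1.0, 1.0]
eer, thr = compute_eer(bonafide_scores, spoof_scores)
print(f"EER={eer:.2%}, threshold={thr}")
```

The resulting `thr` can then be passed as `threshold=` to `classify()`. The scan is quadratic in the number of trials, which is fine for a sketch; for large evaluation sets a sorted sweep or an ROC utility from a metrics library is the more usual choice.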
+
+ ## Limitations and notes
+
+ - This is a **pre-release** model.
+ - Significantly stronger models are planned for **Q3–Q4 2026**; stay tuned.
+
+ ## License
+
+ Apache-2.0 (see the `license` field in the model repo header).
+
+ ## Contacts
+
+ - TG channel: https://t.me/korallll_ai
+ - email: k.n.borodin@mtuci.ru
+ - website: https://lab260.ru/
config.json ADDED
@@ -0,0 +1 @@
+ {}
model.py ADDED
@@ -0,0 +1,955 @@
+ import math
+
+ import torch
+ import torch.nn as nn
+ from transformers import Wav2Vec2Model
+ import torch.nn.functional as F
+ from huggingface_hub import PyTorchModelHubMixin
+
+
+ class KANLinear(torch.nn.Module):
+     def __init__(
+         self,
+         in_features,
+         out_features,
+         grid_size=16,
+         spline_order=4,
+         scale_noise=0.1,
+         scale_base=1.0,
+         scale_spline=1.0,
+         enable_standalone_scale_spline=True,
+         base_activation=torch.nn.PReLU,
+         grid_eps=0.02,
+         grid_range=[-1, 1],
+     ):
+         super(KANLinear, self).__init__()
+         self.in_features = in_features
+         self.out_features = out_features
+         self.grid_size = grid_size
+         self.spline_order = spline_order
+
+         h = (grid_range[1] - grid_range[0]) / grid_size
+         grid = (
+             (
+                 torch.arange(-spline_order, grid_size + spline_order + 1) * h
+                 + grid_range[0]
+             )
+             .expand(in_features, -1)
+             .contiguous()
+         )
+         self.register_buffer("grid", grid)
+
+         self.base_weight = torch.nn.Parameter(torch.Tensor(out_features, in_features))
+         self.spline_weight = torch.nn.Parameter(
+             torch.Tensor(out_features, in_features, grid_size + spline_order)
+         )
+         if enable_standalone_scale_spline:
+             self.spline_scaler = torch.nn.Parameter(
+                 torch.Tensor(out_features, in_features)
+             )
+
+         self.scale_noise = scale_noise
+         self.scale_base = scale_base
+         self.scale_spline = scale_spline
+         self.enable_standalone_scale_spline = enable_standalone_scale_spline
+         self.base_activation = base_activation()
+         self.grid_eps = grid_eps
+
+         self.reset_parameters()
+
+     def reset_parameters(self):
+         torch.nn.init.kaiming_uniform_(self.base_weight, a=math.sqrt(5) * self.scale_base)
+         with torch.no_grad():
+             noise = (
+                 (
+                     torch.rand(self.grid_size + 1, self.in_features, self.out_features)
+                     - 1 / 2
+                 )
+                 * self.scale_noise
+                 / self.grid_size
+             )
+             self.spline_weight.data.copy_(
+                 (self.scale_spline if not self.enable_standalone_scale_spline else 1.0)
+                 * self.curve2coeff(
+                     self.grid.T[self.spline_order:-self.spline_order],
+                     noise,
+                 )
+             )
+             if self.enable_standalone_scale_spline:
+                 # torch.nn.init.constant_(self.spline_scaler, self.scale_spline)
+                 torch.nn.init.kaiming_uniform_(self.spline_scaler, a=math.sqrt(5) * self.scale_spline)
+
+     def b_splines(self, x: torch.Tensor):
+         """
+         Compute the B-spline bases for the given input tensor.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, in_features).
+
+         Returns:
+             torch.Tensor: B-spline bases tensor of shape (batch_size, in_features, grid_size + spline_order).
+         """
+         assert x.dim() == 2 and x.size(1) == self.in_features
+
+         grid: torch.Tensor = (
+             self.grid
+         )  # (in_features, grid_size + 2 * spline_order + 1)
+         x = x.unsqueeze(-1)
+         bases = ((x >= grid[:, :-1]) & (x < grid[:, 1:])).to(x.dtype)
+         for k in range(1, self.spline_order + 1):
+             bases = (
+                 (x - grid[:, : -(k + 1)])
+                 / (grid[:, k:-1] - grid[:, : -(k + 1)])
+                 * bases[:, :, :-1]
+             ) + (
+                 (grid[:, k + 1:] - x)
+                 / (grid[:, k + 1:] - grid[:, 1:(-k)])
+                 * bases[:, :, 1:]
+             )
+
+         assert bases.size() == (
+             x.size(0),
+             self.in_features,
+             self.grid_size + self.spline_order,
+         )
+         return bases.contiguous()
+
+     def curve2coeff(self, x: torch.Tensor, y: torch.Tensor):
+         """
+         Compute the coefficients of the curve that interpolates the given points.
+
+         Args:
+             x (torch.Tensor): Input tensor of shape (batch_size, in_features).
+             y (torch.Tensor): Output tensor of shape (batch_size, in_features, out_features).
+
+         Returns:
+             torch.Tensor: Coefficients tensor of shape (out_features, in_features, grid_size + spline_order).
+         """
+         assert x.dim() == 2 and x.size(1) == self.in_features
+         assert y.size() == (x.size(0), self.in_features, self.out_features)
+
+         A = self.b_splines(x).transpose(
+             0, 1
+         )  # (in_features, batch_size, grid_size + spline_order)
+         B = y.transpose(0, 1)  # (in_features, batch_size, out_features)
+         solution = torch.linalg.lstsq(
+             A, B
+         ).solution  # (in_features, grid_size + spline_order, out_features)
+         result = solution.permute(
+             2, 0, 1
+         )  # (out_features, in_features, grid_size + spline_order)
+
+         assert result.size() == (
+             self.out_features,
+             self.in_features,
+             self.grid_size + self.spline_order,
+         )
+         return result.contiguous()
+
+     @property
+     def scaled_spline_weight(self):
+         return self.spline_weight * (
+             self.spline_scaler.unsqueeze(-1)
+             if self.enable_standalone_scale_spline
+             else 1.0
+         )
+
+     def forward(self, x: torch.Tensor):
+         assert x.size(-1) == self.in_features
+         original_shape = x.shape
+         x = x.reshape(-1, self.in_features)
+
+         base_output = F.linear(self.base_activation(x), self.base_weight)
+         spline_output = F.linear(
+             self.b_splines(x).view(x.size(0), -1),
+             self.scaled_spline_weight.reshape(self.out_features, -1),
+         )
+         output = base_output + spline_output
+         # print(*original_shape[:-1], output.shape)
+         output = output.view(*original_shape[:-1], self.out_features)
+         return output
+
+
+ class Wav2Vec2Encoder(nn.Module):
+     """SSL encoder based on Hugging Face's Wav2Vec2 model."""
+
+     def __init__(self,
+                  model_name_or_path: str = "facebook/wav2vec2-base-960h",
+                  ssl_out_dim: int = 1024,
+                  use_ssl_n_layers: int = None,
+                  freeze_ssl_n_layers: int = 0,
+                  output_attentions: bool = False,
+                  output_hidden_states: bool = False,
+                  normalize_waveform: bool = True):
+         """Initialize the Wav2Vec2 encoder.
+
+         Args:
+             model_name_or_path: HuggingFace model name or path to local model.
+             ssl_out_dim: Output dimension of the Wav2Vec2 encoder.
+             use_ssl_n_layers: Number of Wav2Vec2 layers to use. If None, use all layers.
+             freeze_ssl_n_layers: Number of Wav2Vec2 layers to freeze during training.
+             output_attentions: Whether to output attentions.
+             output_hidden_states: Whether to output hidden states.
+             normalize_waveform: Whether to normalize the waveform input.
+         """
+         super().__init__()
+
+         self.model_name_or_path = model_name_or_path
+         self.ssl_out_dim = ssl_out_dim
+         self.use_ssl_n_layers = use_ssl_n_layers
+         self.freeze_ssl_n_layers = freeze_ssl_n_layers
+         self.output_attentions = output_attentions
+         self.output_hidden_states = output_hidden_states
+         self.normalize_waveform = normalize_waveform
+
+         # Load Wav2Vec2 model
+         self.model = Wav2Vec2Model.from_pretrained(
+             model_name_or_path,
+             gradient_checkpointing=False)
+         self.model.config.apply_spec_augment = False
+         self.model.masked_spec_embed = None
+
+         # Handle layer freezing
+         if freeze_ssl_n_layers > 0:
+             self._freeze_layers(freeze_ssl_n_layers)
+
+     def _freeze_layers(self, n_layers):
+         """Freeze the first n_layers layers of the Wav2Vec2 encoder.
+
+         Args:
+             n_layers: Number of layers to freeze.
+         """
+         # Freeze feature extractor
+         if n_layers > 0:
+             for param in self.model.feature_extractor.parameters():
+                 param.requires_grad = False
+
+         # Freeze encoder layers
+         encoder_layers = self.model.encoder.layers
+         total_layers = len(encoder_layers)
+         layers_to_freeze = min(n_layers - 1, total_layers)  # -1 because feature_extractor counts as one layer
+
+         if layers_to_freeze > 0:
+             for i in range(layers_to_freeze):
+                 for param in encoder_layers[i].parameters():
+                     param.requires_grad = False
+
+     def forward(self, x):
+         """Forward pass through the Wav2Vec2 encoder.
+
+         Args:
+             x: Input tensor of shape (batch_size, sequence_length, channels)
+
+         Returns:
+             Extracted features of shape (batch_size, sequence_length, ssl_out_dim)
+         """
+         # Handle shape: convert (batch_size, sequence_length, channels) to (batch_size, sequence_length)
+         if x.ndim == 3:
+             x = x.squeeze(-1)  # Remove channel dimension if present
+
+         # Normalize input if specified
+         if self.normalize_waveform:
+             x = x / (torch.max(torch.abs(x), dim=1, keepdim=True)[0] + 1e-8)
+
+         # Wav2Vec2 forward pass
+         outputs = self.model(
+             x,
+             output_attentions=self.output_attentions,
+             output_hidden_states=self.output_hidden_states,
+             return_dict=True
+         )
+
+         # Extract last hidden state
+         last_hidden_state = outputs.last_hidden_state
+
+         # Optionally use only a subset of layers (if use_ssl_n_layers is set and output_hidden_states is True)
+         if self.use_ssl_n_layers is not None and self.output_hidden_states and outputs.hidden_states is not None:
+             # Use the last N hidden states and average them
+             selected = outputs.hidden_states[-self.use_ssl_n_layers:]
+             last_hidden_state = torch.mean(torch.stack(selected, dim=0), dim=0)
+         del outputs
+
+         return last_hidden_state
+
+
+ class MLPBridge(nn.Module):
+     """MLP bridge between SSL encoder and AASIST model."""
+
+     def __init__(self, input_dim: int, output_dim: int, hidden_dim: int = None,
+                  dropout: float = 0.1, activation=nn.ReLU, n_layers: int = 1):
+         """Initialize the MLP bridge.
+
+         Args:
+             input_dim: The input dimension from the SSL encoder.
+             output_dim: The output dimension for the AASIST model.
+             hidden_dim: Hidden dimension size. If None, use the average of input and output dims.
+             dropout: Dropout probability to apply between layers.
+             activation: Activation to use, given as a module class or a module instance.
+             n_layers: Number of MLP layers (repeats of Linear+Activation+Dropout blocks).
+         """
+         super().__init__()
+
+         if hidden_dim is None:
+             hidden_dim = (input_dim + output_dim) // 2
+
+         self.input_dim = input_dim
+         self.output_dim = output_dim
+         self.hidden_dim = hidden_dim
+         self.n_layers = n_layers
+
+         assert hasattr(activation, 'forward') and callable(getattr(activation, 'forward', None)), "Activation class must have a callable forward() method."
+         # nn.Sequential only accepts module instances, so instantiate the
+         # activation when a class (e.g. the default nn.ReLU) is passed.
+         act_fn = activation() if isinstance(activation, type) else activation
+
+         layers = []
+         for i in range(n_layers):
+             in_dim = input_dim if i == 0 else hidden_dim
+             out_dim = hidden_dim
+             layers.append(nn.Linear(in_dim, out_dim))
+             layers.append(act_fn)
+             layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+         # Final output layer
+         layers.append(nn.Linear(hidden_dim, output_dim))
+         layers.append(nn.Dropout(dropout) if dropout > 0 else nn.Identity())
+
+         self.mlp = nn.Sequential(*layers)
+
+     def forward(self, x):
+         """Forward pass through the bridge.
+
+         Args:
+             x: The input tensor from the SSL encoder.
+
+         Returns:
+             The transformed tensor for the AASIST model.
+         """
+         return self.mlp(x)
+
+
+ class HtrgGraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, size, layer="KANLinear", **kwargs):
+         super().__init__()
+         if layer == "KANLinear":
+             self.proj_type1 = KANLinear(in_dim, in_dim)
+             self.proj_type2 = KANLinear(in_dim, in_dim)
+             self.att_proj = KANLinear(in_dim, out_dim)
+             self.att_projM = KANLinear(in_dim, out_dim)
+             self.proj_with_att = KANLinear(in_dim, out_dim)
+             self.proj_without_att = KANLinear(in_dim, out_dim)
+             self.proj_with_attM = KANLinear(in_dim, out_dim)
+             self.proj_without_attM = KANLinear(in_dim, out_dim)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.att_weight11 = self._init_new_params(out_dim, 1)
+         self.att_weight22 = self._init_new_params(out_dim, 1)
+         self.att_weight12 = self._init_new_params(out_dim, 1)
+         self.att_weightM = self._init_new_params(out_dim, 1)
+         self.bn = nn.BatchNorm1d(out_dim)
+         self.input_drop = nn.Dropout(p=0.2)
+         self.act = nn.SELU(inplace=True)
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x1, x2, master=None):
+         '''
+         x1 :(#bs, #node, #dim)
+         x2 :(#bs, #node, #dim)
+         '''
+         num_type1 = x1.size(1)
+         num_type2 = x2.size(1)
+
+         x1 = self.proj_type1(x1)
+         x2 = self.proj_type2(x2)
+
+         x = torch.cat([x1, x2], dim=1)
+
+         if master is None:
+             master = torch.mean(x, dim=1, keepdim=True)
+
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x, num_type1, num_type2)
+
+         # directional edge for master node
+         master = self._update_master(x, master)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         # x = self.act(x)
+
+         x1 = x.narrow(1, 0, num_type1)
+         x2 = x.narrow(1, num_type1, num_type2)
+
+         return x1, x2, master
+
+     def _update_master(self, x, master):
+
+         att_map = self._derive_att_map_master(x, master)
+         master = self._project_master(x, master, att_map)
+
+         return master
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map_master(self, x, master):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, 1)
+         '''
+         att_map = x * master
+         att_map = torch.tanh(self.att_projM(att_map))
+
+         att_map = torch.matmul(att_map, self.att_weightM)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _derive_att_map(self, x, num_type1, num_type2):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+
+         att_board = torch.zeros_like(att_map[:, :, :, 0]).unsqueeze(-1)
+
+         att_board[:, :num_type1, :num_type1, :] = torch.matmul(
+             att_map[:, :num_type1, :num_type1, :], self.att_weight11)
+         att_board[:, num_type1:, num_type1:, :] = torch.matmul(
+             att_map[:, num_type1:, num_type1:, :], self.att_weight22)
+         att_board[:, :num_type1, num_type1:, :] = torch.matmul(
+             att_map[:, :num_type1, num_type1:, :], self.att_weight12)
+         att_board[:, num_type1:, :num_type1, :] = torch.matmul(
+             att_map[:, num_type1:, :num_type1, :], self.att_weight12)
+
+         att_map = att_board
+
+         # att_map = torch.matmul(att_map, self.att_weight12)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _project_master(self, x, master, att_map):
+
+         x1 = self.proj_with_attM(torch.matmul(
+             att_map.squeeze(-1).unsqueeze(1), x))
+         x2 = self.proj_without_attM(master)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class GraphPool(nn.Module):
+     def __init__(self, k: float, in_dim: int, p, size, layer="KANLinear"):
+         super().__init__()
+         self.k = k
+         self.sigmoid = nn.Sigmoid()
+         if layer == "KANLinear":
+             self.proj = KANLinear(in_dim, 1)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.drop = nn.Dropout(p=p) if p > 0 else nn.Identity()
+         self.in_dim = in_dim
+
+     def forward(self, h):
+         Z = self.drop(h)
+         weights = self.proj(Z)
+         scores = self.sigmoid(weights)
+         new_h = self.top_k_graph(scores, h, self.k)
+
+         return new_h
+
+     def top_k_graph(self, scores, h, k):
+         """
+         args
+         =====
+         scores: attention-based weights (#bs, #node, 1)
+         h: graph data (#bs, #node, #dim)
+         k: ratio of remaining nodes, (float)
+
+         returns
+         =====
+         h: graph pool applied data (#bs, #node', #dim)
+         """
+         _, n_nodes, n_feat = h.size()
+         n_nodes = max(int(n_nodes * k), 1)
+         _, idx = torch.topk(scores, n_nodes, dim=1)
+         idx = idx.expand(-1, -1, n_feat)
+
+         h = h * scores
+         h = torch.gather(h, 1, idx)
+
+         return h
+
+
+ class GraphAttentionLayer(nn.Module):
+     def __init__(self, in_dim, out_dim, layer="KANLinear", **kwargs):
+         super().__init__()
+         # attention map
+         if layer == "KANLinear":
+             self.att_proj = KANLinear(in_dim, out_dim)
+             self.proj_with_att = KANLinear(in_dim, out_dim)
+             self.proj_without_att = KANLinear(in_dim, out_dim)
+         else:
+             raise ValueError(f"Invalid layer type: {layer}")
+         self.att_weight = self._init_new_params(out_dim, 1)
+
+         # batch norm
+         self.bn = nn.BatchNorm1d(out_dim)
+
+         # dropout for inputs
+         self.input_drop = nn.Dropout(p=0.2)
+
+         # activate
+         self.act = nn.SELU(inplace=True)
+
+         # temperature
+         self.temp = 1.
+         if "temperature" in kwargs:
+             self.temp = kwargs["temperature"]
+
+     def forward(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         '''
+         # apply input dropout
+         x = self.input_drop(x)
+
+         # derive attention map
+         att_map = self._derive_att_map(x)
+
+         # projection
+         x = self._project(x, att_map)
+
+         # apply batch norm
+         x = self._apply_BN(x)
+         x = self.act(x)
+         return x
+
+     def _pairwise_mul_nodes(self, x):
+         '''
+         Calculates pairwise multiplication of nodes.
+         - for attention map
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, #dim)
+         '''
+
+         nb_nodes = x.size(1)
+         x = x.unsqueeze(2).expand(-1, -1, nb_nodes, -1)
+         x_mirror = x.transpose(1, 2)
+
+         return x * x_mirror
+
+     def _derive_att_map(self, x):
+         '''
+         x :(#bs, #node, #dim)
+         out_shape :(#bs, #node, #node, 1)
+         '''
+         att_map = self._pairwise_mul_nodes(x)
+         # size: (#bs, #node, #node, #dim_out)
+         att_map = torch.tanh(self.att_proj(att_map))
+         # size: (#bs, #node, #node, 1)
+         att_map = torch.matmul(att_map, self.att_weight)
+
+         # apply temperature
+         att_map = att_map / self.temp
+
+         att_map = F.softmax(att_map, dim=-2)
+
+         return att_map
+
+     def _project(self, x, att_map):
+         x1 = self.proj_with_att(torch.matmul(att_map.squeeze(-1), x))
+         x2 = self.proj_without_att(x)
+
+         return x1 + x2
+
+     def _apply_BN(self, x):
+         org_size = x.size()
+         x = x.view(-1, org_size[-1])
+         x = self.bn(x)
+         x = x.view(org_size)
+
+         return x
+
+     def _init_new_params(self, *size):
+         out = nn.Parameter(torch.FloatTensor(*size))
+         nn.init.xavier_normal_(out)
+         return out
+
+
+ class Res2NetBlock(nn.Module):
+     def __init__(self, in_channels, out_channels, scale=4, kernel_size=(2, 3), stride=1, padding=(1, 1)):
+         super().__init__()
+         assert out_channels % scale == 0, "out_channels must be divisible by scale"
+         self.scale = scale
+         self.width = out_channels // scale
+         self.conv1 = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+         self.convs = nn.ModuleList([
+             nn.Conv2d(self.width, self.width, kernel_size=kernel_size, stride=stride, padding=padding)
+             for _ in range(scale)
+         ])
+         self.bn = nn.BatchNorm2d(out_channels)
+         self.selu = nn.SELU(inplace=True)
+         self.conv3 = nn.Conv2d(out_channels, out_channels, kernel_size=1)
+         self.downsample = None
+         if in_channels != out_channels:
+             self.downsample = nn.Conv2d(in_channels, out_channels, kernel_size=1)
+
+     def forward(self, x):
+         identity = x
+         out = self.conv1(x)
+         xs = torch.chunk(out, self.scale, dim=1)
+         ys = []
+         for s in range(self.scale):
+             if s == 0:
+                 ys.append(self.convs[s](xs[s]))
+             else:
+                 ys.append(self.convs[s](xs[s] + ys[s - 1]))
+         out = torch.cat(ys, dim=1)
+         out = self.bn(out)
+         out = self.selu(out)
+         out = self.conv3(out)
+         if self.downsample is not None:
+             identity = self.downsample(identity)
+         out += identity
+         return out
+
+
+ class Residual_block(nn.Module):
+     def __init__(self, nb_filts, first=False):
+         super().__init__()
+         self.first = first
+
+         if not self.first:
+             self.bn1 = nn.BatchNorm2d(num_features=nb_filts[0])
+         self.conv1 = nn.Conv2d(in_channels=nb_filts[0],
+                                out_channels=nb_filts[1],
+                                kernel_size=(2, 3),
+                                padding=(1, 1),
+                                stride=1)
+         self.selu = nn.SELU(inplace=True)
+
+         self.bn2 = nn.BatchNorm2d(num_features=nb_filts[1])
+         self.conv2 = nn.Conv2d(in_channels=nb_filts[1],
+                                out_channels=nb_filts[1],
+                                kernel_size=(2, 3),
+                                padding=(0, 1),
+                                stride=1)
+
+         if nb_filts[0] != nb_filts[1]:
+             self.downsample = True
+             self.conv_downsample = nn.Conv2d(in_channels=nb_filts[0],
+                                              out_channels=nb_filts[1],
+                                              padding=(0, 1),
+                                              kernel_size=(1, 3),
+                                              stride=1)
+
+         else:
+             self.downsample = False
+
+     def forward(self, x):
+         identity = x
+         if not self.first:
+             out = self.bn1(x)
+             out = self.selu(out)
+         else:
+             out = x
+
+         # print('out',out.shape)
+         out = self.conv1(out)
+
+         # print('aft conv1 out',out.shape)
+         out = self.bn2(out)
+         out = self.selu(out)
+         # print('out',out.shape)
+         out = self.conv2(out)
+         # print('conv2 out',out.shape)
+
+         if self.downsample:
+             identity = self.conv_downsample(identity)
+
+         out += identity
+         # out = self.mp(out)
+         return out
+
+
+ class Encoder(nn.Module):
+     def __init__(self, filts):
+         super().__init__()
+
+         self.first_bn = nn.BatchNorm2d(num_features=1)
+         self.first_bn1 = nn.BatchNorm2d(num_features=64)
+
+         self.selu = nn.SELU(inplace=True)
+         self.enc = nn.Sequential(
+             nn.Sequential(Residual_block(nb_filts=filts[1], first=True)),
+             nn.Sequential(Residual_block(nb_filts=filts[2])),
+             nn.Sequential(Residual_block(nb_filts=filts[3])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4])),
+             nn.Sequential(Residual_block(nb_filts=filts[4]))
+         )
+
+     def forward(self, x):
+
+         x = x.transpose(1, 2)
+         x = x.unsqueeze(dim=1)
+
+         x = F.max_pool2d(torch.abs(x), (3, 3))
+         x = self.first_bn(x)
+         x = self.selu(x)
+
+         # # get embeddings using encoder
+         # # (#bs, #filt, #spec, #seq)
+
+         x = self.enc(x)
+
+         x = self.first_bn1(x)
+         x = self.selu(x)
+
+         return x
+
+
+ class HSGALBranch_v1(nn.Module):
+     def __init__(self, gat_dims, temperatures, pool_ratios, size=200, layer="KANLinear"):
+         super().__init__()
+
+         self.master = nn.Parameter(torch.randn(1, 1, gat_dims[0]))
+         self.HtrgGAT_layer_ST1 = HtrgGraphAttentionLayer(
+             gat_dims[0], gat_dims[1], temperature=temperatures[2], size=size, layer=layer
+         )
+         self.HtrgGAT_layer_ST2 = HtrgGraphAttentionLayer(
+             gat_dims[1], gat_dims[1], temperature=temperatures[2], size=size, layer=layer
+         )
+
+         self.pool_hS = GraphPool(pool_ratios[2], gat_dims[1], 0.3, size=size, layer=layer)
+         self.pool_hT = GraphPool(pool_ratios[2], gat_dims[1], 0.3, size=size, layer=layer)
+
+         self.drop_way = nn.Dropout(0.2, inplace=True)
+
+     def forward(self, out_t, out_s):
+         out_T, out_S, master = self.HtrgGAT_layer_ST1(
+             out_t, out_s, master=self.master
+         )
+
+         out_S = self.pool_hS(out_S)
+         out_T = self.pool_hT(out_T)
+
+         out_T_aug, out_S_aug, master_aug = self.HtrgGAT_layer_ST2(
+             out_T, out_S, master=master
+         )
+         out_T = out_T + out_T_aug
+         out_S = out_S + out_S_aug
+         master = master + master_aug
+
+         out_T = self.drop_way(out_T)
+         out_S = self.drop_way(out_S)
+         master = self.drop_way(master)
+
+         return out_T, out_S, master
+
+
+ class KANAASIST(nn.Module):
+     """KAN-AASIST model with graph attention layers."""
+
+     def __init__(
+         self,
+         d_args={
+             "architecture": "AASIST",
+             "nb_samp": 64600,
+             "filts": [512, [1, 32], [32, 32], [32, 64], [64, 64]],
+             "gat_dims": [64, 32],
+             "pool_ratios": [0.5, 0.5, 0.5, 0.5],
+             "temperatures": [2.0, 2.0, 100.0, 100.0]
+         },
+         encoder=Encoder,
+         size=200,
+         n_frames=400,
+         layer_type="Linear",
+         **kwargs
+     ):
+         super().__init__()
+
+         layer = layer_type
+         self.d_args = d_args
+         filts = d_args["filts"]
+         gat_dims = d_args["gat_dims"]
+         pool_ratios = d_args["pool_ratios"]
+         temperatures = d_args["temperatures"]
+
+         self.drop = nn.Dropout(0.5, inplace=True)
+         self.drop_way = nn.Dropout(0.2, inplace=True)
+
+         self.attention = nn.Sequential(
+             nn.Conv2d(64, 128, kernel_size=(1, 1)),
+             nn.SELU(inplace=True),
+             nn.BatchNorm2d(128),
+             nn.Conv2d(128, 64, kernel_size=(1, 1)),
+         )
+
+         # Learned positional embeddings: one per spectral node (after the
+         # encoder's 3x3 max pooling) and one per input frame.
+         self.pos_S = nn.Parameter(torch.randn(1, filts[0] // 3, filts[-1][-1]))
+         self.pos_T = nn.Parameter(torch.randn(1, n_frames, filts[0]))
+
+         self.GAT_layer_S = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[0], size=size, layer=layer)
+         self.GAT_layer_T = GraphAttentionLayer(filts[-1][-1],
+                                                gat_dims[0],
+                                                temperature=temperatures[1], size=size, layer=layer)
+
+         self.branch1 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch2 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch3 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+         self.branch4 = HSGALBranch_v1(gat_dims, temperatures, pool_ratios, size, layer=layer)
+
+         self.pool_S = GraphPool(pool_ratios[0], gat_dims[0], 0.3, size=size, layer=layer)
+         self.pool_T = GraphPool(pool_ratios[1], gat_dims[0], 0.3, size=size, layer=layer)
+
+         out_features = 2
+         in_features = 5 * gat_dims[1]
+         # Only the KANLinear output layer is implemented, so the default
+         # layer_type="Linear" raises here; callers must pass "KANLinear".
+         if layer == 'KANLinear':
+             self.out_layer = KANLinear(in_features, out_features)
+         else:
+             raise ValueError(
+                 f"Unsupported layer type for the output layer: {layer!r} "
+                 "(only 'KANLinear' is implemented)"
+             )
+         self.enc = encoder(filts=filts)
+
+     def forward(self, x, Freq_aug=False):
+         """Forward pass through the KAN-AASIST model.
+
+         Args:
+             x: Input tensor of shape (batch_size, seq_len, channels);
+                 seq_len must not exceed n_frames, the length of pos_T.
+             Freq_aug: Currently unused in this forward pass.
+
+         Returns:
+             Tensor of shape (batch_size, 2) with one score per class.
+         """
+         x = x + self.pos_T[:, :x.size(1), :]
+         x = self.enc(x)
+         # x is now (#bs, #filt, #spec, #seq); the 1x1-conv attention block
+         # yields one weight per position, normalised per axis below.
+         w = self.attention(x)
+         w1 = F.softmax(w, dim=-1)
+         m = torch.sum(x * w1, dim=-1)
+         e_S = m.transpose(1, 2) + self.pos_S
+
+         gat_S = self.GAT_layer_S(e_S)
+         out_S = self.pool_S(gat_S)  # (#bs, #node, #dim)
+
+         w2 = F.softmax(w, dim=-2)
+         m1 = torch.sum(x * w2, dim=-2)
+
+         e_T = m1.transpose(1, 2)
+
+         gat_T = self.GAT_layer_T(e_T)
+         out_T = self.pool_T(gat_T)
+
+         out_T1, out_S1, master1 = self.branch1(out_T, out_S)
+         out_T2, out_S2, master2 = self.branch2(out_T, out_S)
+         out_T3, out_S3, master3 = self.branch3(out_T, out_S)
+         out_T4, out_S4, master4 = self.branch4(out_T, out_S)
+
+         # Element-wise maximum across the four branches.
+         out_T = torch.amax(torch.stack([out_T1, out_T2, out_T3, out_T4]), dim=0)
+         out_S = torch.amax(torch.stack([out_S1, out_S2, out_S3, out_S4]), dim=0)
+         master = torch.amax(torch.stack([master1, master2, master3, master4]), dim=0)
+
+         T_max, _ = torch.max(torch.abs(out_T), dim=1)
+         T_avg = torch.mean(out_T, dim=1)
+
+         S_max, _ = torch.max(torch.abs(out_S), dim=1)
+         S_avg = torch.mean(out_S, dim=1)
+
+         last_hidden = torch.cat(
+             [T_max, T_avg, S_max, S_avg, master.squeeze(1)], dim=1)
+
+         last_hidden = self.drop(last_hidden)
+         output = self.out_layer(last_hidden)
+
+         return output
+
+
+ class SpectraAASIST3(nn.Module, PyTorchModelHubMixin):
+     def __init__(self, **kwargs):
+         super().__init__()
+         self.ssl_encoder = Wav2Vec2Encoder("facebook/wav2vec2-xls-r-300m",
+                                            1024,
+                                            None,
+                                            0,
+                                            False,
+                                            False,
+                                            False)
+         self.bridge = MLPBridge(1024,
+                                 128,
+                                 hidden_dim=128, dropout=0.1, activation=nn.SELU(), n_layers=1)
+         self.aasist = KANAASIST(
+             d_args={
+                 "architecture": "AASIST",
+                 "nb_samp": 64400,
+                 "filts": [128, [1, 32], [32, 32], [32, 64], [64, 64]],
+                 "gat_dims": [64, 32],
+                 "pool_ratios": [0.5, 0.5, 0.5, 0.5],
+                 "temperatures": [2.0, 2.0, 100.0, 100.0]
+             },
+             size=200,
+             layer_type="KANLinear"
+         )
+
+     def forward(self, x):
+         x = self.ssl_encoder(x)
+         x = self.bridge(x)
+         x = self.aasist(x)
+         return x
+
+     @torch.inference_mode()
+     def classify(self, x, threshold: float = -1.0625009):
+         """Return 1.0 if the class-1 score exceeds `threshold`, else 0.0.
+
+         Expects a batch of one: `.item()` requires a single-element tensor.
+         """
+         x = self.forward(x)[:, 1]
+         x = (x > threshold).float()
+         return x.item()
+
+ spectra_aasist3 = SpectraAASIST3
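
Review note: the shape bookkeeping in `Encoder` and `Residual_block` is easy to misread, so here is a minimal sketch of the arithmetic in plain Python (no PyTorch needed). It assumes the standard convolution/pooling output-size formula with dilation 1. The helper names `conv_out` and `residual_block_shape` and the 201-frame input length are hypothetical, chosen only for illustration.

```python
def conv_out(size, kernel, padding=0, stride=1):
    """Output length of a conv/pool window along one axis (dilation 1)."""
    return (size + 2 * padding - kernel) // stride + 1


def residual_block_shape(h, w):
    """Spatial shape after conv1 then conv2 of Residual_block."""
    # conv1: kernel (2, 3), padding (1, 1) -> height grows by one, width kept
    h1, w1 = conv_out(h, 2, 1), conv_out(w, 3, 1)
    # conv2: kernel (2, 3), padding (0, 1) -> height shrinks back, width kept
    return conv_out(h1, 2, 0), conv_out(w1, 3, 1)


# Assumed example input: 128 bridge features (filts[0]) and 201 frames.
feat_dim, n_frames = 128, 201

# F.max_pool2d(x, (3, 3)) uses stride == kernel size and no padding.
spec = conv_out(feat_dim, 3, 0, 3)
seq = conv_out(n_frames, 3, 0, 3)

# The pooled spectral axis matches the pos_S node count, filts[0] // 3.
assert spec == feat_dim // 3
# conv1 + conv2 preserve the shape, so `out += identity` lines up.
assert residual_block_shape(spec, seq) == (spec, seq)
```

The same formula shows why the `Res2NetBlock` defaults are not shape-preserving: a single `(2, 3)` convolution with padding `(1, 1)` maps height `h` to `h + 1`.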
model.safetensors ADDED
@@ -0,0 +1,3 @@
+ version https://git-lfs.github.com/spec/v1
+ oid sha256:327762caab3106ecee5f26b821ccdc6ab1fe34906a3e7042f416c87927df2088
+ size 1276330956
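
Review note: the three lines above are a Git LFS pointer, not the weights themselves. A minimal sketch of how a downloaded `model.safetensors` could be checked against this pointer follows. It uses only the standard library. `parse_lfs_pointer` and `matches_pointer` are hypothetical helper names, and the local file path is whatever the reader supplies.

```python
import hashlib
import os


def parse_lfs_pointer(text):
    """Parse the key/value lines of a Git LFS pointer file (hypothetical helper)."""
    fields = {}
    for line in text.strip().splitlines():
        key, _, value = line.partition(" ")
        fields[key] = value
    algo, _, digest = fields["oid"].partition(":")
    return {
        "version": fields["version"],
        "hash_algo": algo,
        "digest": digest,
        "size": int(fields["size"]),
    }


def matches_pointer(path, info, chunk_size=1 << 20):
    """Check a local file's size and digest against a parsed pointer."""
    if os.path.getsize(path) != info["size"]:
        return False
    hasher = hashlib.new(info["hash_algo"])
    with open(path, "rb") as fh:
        # Stream the file so a large checkpoint is never fully held in memory.
        for chunk in iter(lambda: fh.read(chunk_size), b""):
            hasher.update(chunk)
    return hasher.hexdigest() == info["digest"]


POINTER = """version https://git-lfs.github.com/spec/v1
oid sha256:327762caab3106ecee5f26b821ccdc6ab1fe34906a3e7042f416c87927df2088
size 1276330956
"""

info = parse_lfs_pointer(POINTER)
```

With a local copy, `matches_pointer("model.safetensors", info)` would return `True` only if both the byte size and the SHA-256 digest agree with the pointer.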