AnodHuang commited on
Commit
7b67210
·
verified ·
1 Parent(s): 31679dc

Upload 5 files

Browse files
config.json ADDED
@@ -0,0 +1,77 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "_name_or_path": "facebook/wav2vec2-base-960h",
3
+ "activation_dropout": 0.1,
4
+ "apply_spec_augment": true,
5
+ "architectures": [
6
+ "Wav2Vec2ForCTC"
7
+ ],
8
+ "attention_dropout": 0.1,
9
+ "bos_token_id": 1,
10
+ "codevector_dim": 256,
11
+ "contrastive_logits_temperature": 0.1,
12
+ "conv_bias": false,
13
+ "conv_dim": [
14
+ 512,
15
+ 512,
16
+ 512,
17
+ 512,
18
+ 512,
19
+ 512,
20
+ 512
21
+ ],
22
+ "conv_kernel": [
23
+ 10,
24
+ 3,
25
+ 3,
26
+ 3,
27
+ 3,
28
+ 2,
29
+ 2
30
+ ],
31
+ "conv_stride": [
32
+ 5,
33
+ 2,
34
+ 2,
35
+ 2,
36
+ 2,
37
+ 2,
38
+ 2
39
+ ],
40
+ "ctc_loss_reduction": "sum",
41
+ "ctc_zero_infinity": false,
42
+ "diversity_loss_weight": 0.1,
43
+ "do_stable_layer_norm": false,
44
+ "eos_token_id": 2,
45
+ "feat_extract_activation": "gelu",
46
+ "feat_extract_dropout": 0.0,
47
+ "feat_extract_norm": "group",
48
+ "feat_proj_dropout": 0.1,
49
+ "feat_quantizer_dropout": 0.0,
50
+ "final_dropout": 0.1,
51
+ "gradient_checkpointing": false,
52
+ "hidden_act": "gelu",
53
+ "hidden_dropout": 0.1,
54
+ "hidden_dropout_prob": 0.1,
55
+ "hidden_size": 768,
56
+ "initializer_range": 0.02,
57
+ "intermediate_size": 3072,
58
+ "layer_norm_eps": 1e-05,
59
+ "layerdrop": 0.1,
60
+ "mask_feature_length": 10,
61
+ "mask_feature_prob": 0.0,
62
+ "mask_time_length": 10,
63
+ "mask_time_prob": 0.05,
64
+ "model_type": "wav2vec2",
65
+ "num_attention_heads": 12,
66
+ "num_codevector_groups": 2,
67
+ "num_codevectors_per_group": 320,
68
+ "num_conv_pos_embedding_groups": 16,
69
+ "num_conv_pos_embeddings": 128,
70
+ "num_feat_extract_layers": 7,
71
+ "num_hidden_layers": 12,
72
+ "num_negatives": 100,
73
+ "pad_token_id": 0,
74
+ "proj_codevector_dim": 256,
75
+ "transformers_version": "4.7.0.dev0",
76
+ "vocab_size": 32
77
+ }
preprocessor_config.json ADDED
@@ -0,0 +1,8 @@
 
 
 
 
 
 
 
 
 
1
+ {
2
+ "do_normalize": true,
3
+ "feature_size": 1,
4
+ "padding_side": "right",
5
+ "padding_value": 0.0,
6
+ "return_attention_mask": false,
7
+ "sampling_rate": 16000
8
+ }
pytorch_model.bin ADDED
@@ -0,0 +1,3 @@
 
 
 
 
1
+ version https://git-lfs.github.com/spec/v1
2
+ oid sha256:c34f9827b034a1b9141dbf6f652f8a60eda61cdf5771c9e05bfa99033c92cd96
3
+ size 377667514
train_wav2vec_base.py ADDED
@@ -0,0 +1,379 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
# train_wav2vec2_stream_4090_winfix.py
import os
# Key fix: disable torch.compile / torchdynamo before importing torch, to
# avoid the cProfile/profile module conflict it can trigger.
os.environ["TORCHDYNAMO_DISABLE"] = "1"
os.environ["TORCH_COMPILE_DISABLE"] = "1"

import json
import time
import math
import argparse
from glob import glob
import io

import numpy as np
import torch
import torch.nn.functional as F
from torch.utils.data import DataLoader

from datasets import load_dataset, Audio
from transformers import AutoFeatureExtractor, Wav2Vec2ForSequenceClassification

import soundfile as sf


# ==============
# Defaults: fully offline hub access (mainland-China-friendly environment)
# ==============
os.environ.setdefault("HF_HUB_OFFLINE", "1")
os.environ.setdefault("TRANSFORMERS_OFFLINE", "1")

# Extra safety: explicitly disable dynamo (more reliable on some torch versions).
try:
    import torch._dynamo
    torch._dynamo.disable()
except Exception:
    pass

# Column names used by the parquet shards and the jsonl label index.
AUDIO_COL = "wav"            # parquet column holding the audio payload
PARQUET_KEY_COL = "__key__"  # parquet column holding the sample key (no ".wav" suffix)
JSONL_KEY_COL = "member"     # jsonl field naming the member file ("<key>.wav")
JSONL_LABEL_COL = "key"      # jsonl field holding the label (1/"bonafide" = bonafide)
42
+
43
+
44
def parse_args():
    """Parse command-line options for streaming wav2vec2 fine-tuning.

    Returns:
        argparse.Namespace with data/model paths, training hyper-parameters
        and DataLoader settings.
    """
    p = argparse.ArgumentParser()

    p.add_argument("--data_dir", type=str, default=r"./ASV_Spoof_2019_LA_SNR_50MB")
    p.add_argument("--model_dir", type=str, default=r"./wav2vecbase")
    p.add_argument("--out", type=str, default="./wav2vec2_stream_out_4090")

    p.add_argument("--sr", type=int, default=16000)
    p.add_argument("--max_sec", type=float, default=6.0)

    p.add_argument("--epochs", type=int, default=3)
    p.add_argument("--batch", type=int, default=16)
    p.add_argument("--grad_accum", type=int, default=1)

    p.add_argument("--lr", type=float, default=2e-5)
    p.add_argument("--weight_decay", type=float, default=0.01)

    p.add_argument("--log_every", type=int, default=20)
    # FIX: `store_true` combined with default=True could never be switched off
    # from the CLI.  BooleanOptionalAction keeps the original flag working and
    # additionally accepts the `--no-<flag>` negation.
    p.add_argument("--eval_every_epoch", action=argparse.BooleanOptionalAction, default=True)

    p.add_argument("--train_buffer_shuffle", type=int, default=50000)

    p.add_argument("--val_take", type=int, default=0)
    p.add_argument("--fp16", action=argparse.BooleanOptionalAction, default=True)

    # Windows is more stable with 2-4 workers.
    p.add_argument("--num_workers", type=int, default=2)
    p.add_argument("--pin_memory", action=argparse.BooleanOptionalAction, default=True)

    # Used only to derive steps_per_epoch for the streaming (len-less) dataset.
    p.add_argument("--train_size_hint", type=int, default=45600)

    return p.parse_args()
76
+
77
+
78
def find_parquet_files(data_dir: str, split: str):
    """Return the sorted parquet shards for *split* under ``data_dir/default``."""
    subdir = {"train": "partial-train", "validation": "partial-validation", "test": "partial-test"}[split]
    shard_dir = os.path.join(data_dir, "default", subdir)
    matches = sorted(glob(os.path.join(shard_dir, "*.parquet")))
    if not matches:
        raise FileNotFoundError(f"没找到 {split} parquet: {shard_dir}/*.parquet")
    return matches
85
+
86
+
87
def find_jsonl(data_dir: str, split: str):
    """Locate the ``{split}.jsonl`` label index, trying the usual layouts in order."""
    for parts in (
        (data_dir, "index", f"{split}.jsonl"),
        (data_dir, f"{split}.jsonl"),
        (data_dir, "default", "index", f"{split}.jsonl"),
        (data_dir, "default", f"{split}.jsonl"),
    ):
        candidate = os.path.join(*parts)
        if os.path.isfile(candidate):
            return candidate
    raise FileNotFoundError(f"找不到 {split}.jsonl(建议放到 {data_dir}/index/{split}.jsonl)")
98
+
99
+
100
def load_member2label(jsonl_path: str):
    """Read a jsonl index and build a mapping member name -> binary label.

    Label convention: 1 = bonafide (numeric key == 1 or string "bonafide"),
    0 = spoof.  Blank lines and records missing either field are skipped.
    """
    mapping = {}
    with open(jsonl_path, "r", encoding="utf-8") as fh:
        for raw in fh:
            raw = raw.strip()
            if not raw:
                continue
            record = json.loads(raw)
            member = record.get(JSONL_KEY_COL, None)
            key = record.get(JSONL_LABEL_COL, None)
            if member is None or key is None:
                continue

            # Numeric keys: 1 means bonafide; string keys: the literal "bonafide".
            if isinstance(key, (int, np.integer)):
                label = 1 if int(key) == 1 else 0
            else:
                label = 1 if str(key).lower() == "bonafide" else 0

            mapping[str(member)] = int(label)

    if not mapping:
        raise ValueError(f"{jsonl_path} 没读到任何 member->label")
    return mapping
124
+
125
+
126
def decode_wav_any(w, target_sr: int):
    """Decode one dataset audio value into ``(float32 samples, sample rate)``.

    Accepts the HF Audio dict forms ({"bytes": ...} or {"array": ...}),
    raw encoded bytes, or an already-decoded sample sequence (assumed to
    be at *target_sr*).
    """
    if isinstance(w, dict):
        blob = w["bytes"] if "bytes" in w else None
        if blob is not None:
            samples, native_sr = sf.read(io.BytesIO(blob), dtype="float32")
            return samples, native_sr
        arr = w["array"] if "array" in w else None
        if arr is not None:
            samples = np.asarray(arr, dtype=np.float32)
            return samples, int(w.get("sampling_rate", target_sr))

    if isinstance(w, (bytes, bytearray)):
        samples, native_sr = sf.read(io.BytesIO(w), dtype="float32")
        return samples, native_sr

    # Fallback: treat it as a plain sample sequence already at target_sr.
    return np.asarray(w, dtype=np.float32), target_sr
142
+
143
+
144
def cheap_resample(x: np.ndarray, sr0: int, sr1: int):
    """Linear-interpolation resample of 1-D samples from *sr0* to *sr1* Hz.

    Cheap stand-in for a proper polyphase resampler; returns *x* unchanged
    when the rates already match, and at least one sample otherwise.
    """
    if sr0 == sr1:
        return x
    out_len = int(round(len(x) * (sr1 / sr0)))
    if out_len <= 1:
        return x[:1]
    src_positions = np.linspace(0, len(x) - 1, out_len).astype(np.float64)
    src_grid = np.arange(len(x), dtype=np.float64)
    return np.interp(src_positions, src_grid, x).astype(np.float32)
154
+
155
+
156
def disable_audio_decoding(ds, audio_col: str, sr: int):
    """Best-effort: stop `datasets` from eagerly decoding the audio column.

    Tries the newer ``ds.decode(False)`` API first, then falls back to
    casting the column with ``Audio(decode=False)``; if neither API exists
    the dataset is returned unchanged (decoding still happens, just slower).
    """
    decode_api = getattr(ds, "decode", None)
    if decode_api is not None:
        try:
            return decode_api(False)
        except TypeError:
            pass

    cast_api = getattr(ds, "cast_column", None)
    if cast_api is not None:
        try:
            return cast_api(audio_col, Audio(decode=False))
        except TypeError:
            return cast_api(audio_col, Audio(sampling_rate=sr))

    return ds
170
+
171
+
172
class StreamCollator:
    """Collate streaming parquet rows into padded wav2vec2 model inputs.

    Looks up each row's label via the member->label map, decodes/resamples
    the audio, crops or zero-pads it to ``max_sec`` seconds, then runs the
    feature extractor to build the batch tensors (plus a ``labels`` tensor).
    """

    def __init__(self, feature_extractor, member2label, sr=16000, max_sec=6.0):
        self.fe = feature_extractor       # HF feature extractor (callable)
        self.m2l = member2label           # member name ("<key>.wav") -> 0/1 label
        self.sr = sr                      # target sample rate
        self.max_len = int(sr * max_sec)  # fixed clip length in samples

    def __call__(self, batch):
        audios = []
        labels = []

        for ex in batch:
            # FIX: the empty-key check used to run *after* appending ".wav",
            # so it could never fire; validate the raw key first.
            raw_key = str(ex.get(PARQUET_KEY_COL, ""))
            kk = raw_key + ".wav"
            if raw_key == "" or kk not in self.m2l:
                raise ValueError(f"jsonl 找不到 member={kk} 的标签(检查 parquet.__key__ 与 jsonl.member 是否一致)")
            labels.append(self.m2l[kk])

            w = ex.get(AUDIO_COL, None)
            if w is None:
                raise ValueError(f"样本缺少音频列 {AUDIO_COL}")

            x, sr0 = decode_wav_any(w, self.sr)
            x = np.asarray(x, dtype=np.float32)
            if x.ndim > 1:
                x = x.mean(axis=-1)  # downmix to mono

            if sr0 != self.sr:
                x = cheap_resample(x, sr0, self.sr)

            # Crop long clips / zero-pad short ones to the fixed length.
            if len(x) >= self.max_len:
                x = x[: self.max_len]
            else:
                x = np.pad(x, (0, self.max_len - len(x)))

            audios.append(x)

        inputs = self.fe(audios, sampling_rate=self.sr, return_tensors="pt", padding=True)
        inputs["labels"] = torch.tensor(labels, dtype=torch.long)
        return inputs
211
+
212
+
213
def _average_ranks(scores: np.ndarray) -> np.ndarray:
    """Return 1-based ranks of *scores* with ties assigned their average rank."""
    order = np.argsort(scores, kind="mergesort")
    inv = np.empty(len(scores), dtype=np.int64)
    inv[order] = np.arange(len(scores))
    s = scores[order]
    boundaries = np.r_[True, s[1:] != s[:-1]]          # start of each tie group
    dense = boundaries.cumsum()[inv]                   # dense rank per element
    edges = np.r_[np.nonzero(boundaries)[0], len(s)]   # group starts + end
    return 0.5 * (edges[dense] + edges[dense - 1] + 1)


@torch.no_grad()
def eval_loop(model, dl, device, fp16: bool):
    """Run validation over *dl* and compute accuracy / F1 / ROC-AUC.

    Positive class is index 1 (bonafide).  Returns a dict with keys
    ``acc``, ``f1``, ``roc_auc`` and the sample count ``n``.  Metrics are
    NaN when the batch stream is empty (or, for AUC, single-class).
    """
    model.eval()
    all_probs, all_preds, all_labels = [], [], []

    for batch in dl:
        batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}
        with torch.amp.autocast("cuda", enabled=fp16):
            logits = model(**batch).logits

        probs = F.softmax(logits, dim=-1)[:, 1]
        preds = torch.argmax(logits, dim=-1)

        all_probs.append(probs.detach().cpu().numpy())
        all_preds.append(preds.detach().cpu().numpy())
        all_labels.append(batch["labels"].detach().cpu().numpy())

    probs = np.concatenate(all_probs) if all_probs else np.array([], dtype=np.float32)
    preds = np.concatenate(all_preds) if all_preds else np.array([], dtype=np.int64)
    labels = np.concatenate(all_labels) if all_labels else np.array([], dtype=np.int64)

    acc = float((preds == labels).mean()) if len(labels) else float("nan")

    tp = int(((preds == 1) & (labels == 1)).sum())
    fp = int(((preds == 1) & (labels == 0)).sum())
    fn = int(((preds == 0) & (labels == 1)).sum())
    precision = tp / (tp + fp + 1e-9)
    recall = tp / (tp + fn + 1e-9)
    f1 = float(2 * precision * recall / (precision + recall + 1e-9))

    roc_auc = float("nan")
    if len(labels) and len(np.unique(labels)) == 2:
        n_pos = int((labels == 1).sum())
        n_neg = int((labels == 0).sum())
        if n_pos > 0 and n_neg > 0:
            # FIX: use tie-averaged ranks (Mann-Whitney U statistic).  The
            # previous ordinal ranking biased the AUC whenever scores were
            # tied, which is common with saturated softmax probabilities.
            ranks = _average_ranks(probs)
            sum_ranks_pos = float(ranks[labels == 1].sum())
            roc_auc = float((sum_ranks_pos - n_pos * (n_pos + 1) / 2) / (n_pos * n_neg))

    model.train()
    return {"acc": acc, "f1": f1, "roc_auc": roc_auc, "n": int(len(labels))}
256
+
257
+
258
def main():
    """Stream-train wav2vec2 for binary spoof detection from parquet shards.

    Loads label indices from jsonl, streams audio from parquet (no full
    download/decode), fine-tunes ``Wav2Vec2ForSequenceClassification`` with
    AMP + gradient accumulation, evaluates each epoch, and saves the latest
    checkpoint to ``<out>/last``.
    """
    args = parse_args()

    assert torch.cuda.is_available(), "CUDA 不可用"
    device = torch.device("cuda")
    print("CUDA OK:", torch.cuda.get_device_name(0))

    # Throughput knobs: cudnn autotune + TF32 matmuls.
    torch.backends.cudnn.benchmark = True
    torch.backends.cuda.matmul.allow_tf32 = True
    torch.backends.cudnn.allow_tf32 = True

    # --- data: parquet shards + jsonl label maps ---
    train_files = find_parquet_files(args.data_dir, "train")
    val_files = find_parquet_files(args.data_dir, "validation")
    train_jsonl = find_jsonl(args.data_dir, "train")
    val_jsonl = find_jsonl(args.data_dir, "validation")

    train_m2l = load_member2label(train_jsonl)
    val_m2l = load_member2label(val_jsonl)
    print("labels loaded:", len(train_m2l), len(val_m2l))

    train_stream = load_dataset("parquet", data_files={"train": train_files}, streaming=True)["train"]
    train_stream = disable_audio_decoding(train_stream, AUDIO_COL, args.sr)
    train_stream = train_stream.shuffle(buffer_size=args.train_buffer_shuffle, seed=42)

    val_stream = load_dataset("parquet", data_files={"validation": val_files}, streaming=True)["validation"]
    val_stream = disable_audio_decoding(val_stream, AUDIO_COL, args.sr)
    if args.val_take and args.val_take > 0:
        val_stream = val_stream.take(int(args.val_take))

    # --- model: swap the CTC head for a 2-way classification head ---
    fe = AutoFeatureExtractor.from_pretrained(args.model_dir, local_files_only=True)
    model = Wav2Vec2ForSequenceClassification.from_pretrained(
        args.model_dir,
        num_labels=2,
        id2label={0: "spoof", 1: "bonafide"},
        label2id={"spoof": 0, "bonafide": 1},
        ignore_mismatched_sizes=True,  # base checkpoint has a CTC head
        local_files_only=True,
    ).to(device)
    model.train()

    train_collator = StreamCollator(fe, train_m2l, sr=args.sr, max_sec=args.max_sec)
    val_collator = StreamCollator(fe, val_m2l, sr=args.sr, max_sec=args.max_sec)

    train_dl = DataLoader(
        train_stream,
        batch_size=args.batch,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=train_collator,
    )
    val_dl = DataLoader(
        val_stream,
        batch_size=args.batch,
        num_workers=args.num_workers,
        pin_memory=args.pin_memory,
        collate_fn=val_collator,
    )

    # With dynamo disabled at import time this no longer triggers the
    # torch._dynamo -> cProfile clash.
    optim = torch.optim.AdamW(model.parameters(), lr=args.lr, weight_decay=args.weight_decay)
    scaler = torch.amp.GradScaler("cuda", enabled=args.fp16)

    os.makedirs(args.out, exist_ok=True)

    global_step = 0

    # Streaming datasets expose no len(); derive steps from the size hint.
    steps_per_epoch = max(1, math.ceil(args.train_size_hint / max(1, args.batch)))
    print(f"steps_per_epoch={steps_per_epoch} (train_size_hint={args.train_size_hint}, batch={args.batch})")

    for epoch in range(1, args.epochs + 1):
        print(f"\n===== EPOCH {epoch}/{args.epochs} =====")
        t0 = time.time()
        running = 0.0  # sum of per-sample losses this epoch
        seen = 0       # samples consumed this epoch

        it = iter(train_dl)
        optim.zero_grad(set_to_none=True)

        for step_in_epoch in range(steps_per_epoch):
            # FIX: if train_size_hint over-estimates the stream length the
            # old code crashed with an unhandled StopIteration; restart the
            # stream and keep training instead.
            try:
                batch = next(it)
            except StopIteration:
                it = iter(train_dl)
                batch = next(it)
            batch = {k: v.to(device, non_blocking=True) for k, v in batch.items()}

            with torch.amp.autocast("cuda", enabled=args.fp16):
                loss = model(**batch).loss
                loss_scaled = loss / args.grad_accum

            scaler.scale(loss_scaled).backward()

            # Optimizer step only on gradient-accumulation boundaries.
            if (step_in_epoch + 1) % args.grad_accum == 0:
                scaler.step(optim)
                scaler.update()
                optim.zero_grad(set_to_none=True)

            running += float(loss.item()) * batch["labels"].size(0)
            seen += int(batch["labels"].size(0))
            global_step += 1

            if global_step % args.log_every == 0:
                avg = running / max(1, seen)
                dt = time.time() - t0
                spd = seen / max(1e-9, dt)
                mem = torch.cuda.memory_allocated() / (1024**3)
                print(f"step {global_step:6d} | loss(avg)={avg:.4f} | samples={seen} | {spd:.1f} samp/s | mem={mem:.2f} GB")

        if args.eval_every_epoch:
            metrics = eval_loop(model, val_dl, device, fp16=args.fp16)
            print(f"[VAL] n={metrics['n']} acc={metrics['acc']:.4f} f1={metrics['f1']:.4f} roc_auc={metrics['roc_auc']:.4f}")

        # Save the latest (not best) checkpoint every epoch.
        last_dir = os.path.join(args.out, "last")
        os.makedirs(last_dir, exist_ok=True)
        model.save_pretrained(last_dir)
        fe.save_pretrained(last_dir)
        print(f"saved last to: {last_dir}")

    print("\nDONE.")
376
+
377
+
378
# Script entry point.
if __name__ == "__main__":
    main()
verify_wav2vecbase.py ADDED
@@ -0,0 +1,209 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import time
4
+ import numpy as np
5
+ import torch
6
+ import soundfile as sf
7
+ from tqdm import tqdm
8
+ import pyarrow.parquet as pq
9
+
10
+ from transformers import (
11
+ Wav2Vec2FeatureExtractor,
12
+ Wav2Vec2ForSequenceClassification,
13
+ )
14
+
15
# =========================
# 0) Configuration
# =========================
PARQUET_DIR = r"D:\capstone\asv_spoof\parquet"

# ✅ the original (pre-fine-tune) model
MODEL_DIR = r"D:\capstone\wav2vecbase"
# ✅ the fine-tuned model (uncomment to evaluate it instead)
# MODEL_DIR = r"D:\capstone\models\wav2vec2_snr"

SPLIT = "test"
BATCH_SIZE = 32  # 16-32 recommended on an RTX 4060
CPU_THREADS = 8

KEY_SPOOF_VALUE = 1  # key == 1 -> spoof (positive class)

PARQUET_FILE = os.path.join(PARQUET_DIR, f"{SPLIT}-00000-of-00001.parquet")
CHECK_LABEL_CONSISTENCY = True
33
+
34
+
35
# =========================
# 1) Audio decoding
# =========================
38
def decode_audio(bytes_blob, path_str):
    """Decode one sample from embedded bytes (preferred) or a file path.

    Returns mono float32 samples and the native sample rate.
    """
    source = io.BytesIO(bytes_blob) if bytes_blob is not None else path_str
    wav, sr = sf.read(source, dtype="float32")
    if wav.ndim > 1:
        wav = wav.mean(axis=1)  # downmix multi-channel to mono
    return wav.astype(np.float32), int(sr)
47
+
48
+
49
def resample(wav, sr, target_sr):
    """Linear-interpolation resample of *wav* from *sr* to *target_sr* Hz."""
    if sr == target_sr:
        return wav
    src_grid = np.linspace(0, 1, len(wav), endpoint=False)
    dst_len = int(len(wav) * target_sr / sr)
    dst_grid = np.linspace(0, 1, dst_len, endpoint=False)
    return np.interp(dst_grid, src_grid, wav).astype(np.float32)
56
+
57
+
58
def key_to_label(k):
    """Map a parquet ``key`` value to the binary label (1 = spoof)."""
    return int(int(k) == KEY_SPOOF_VALUE)
60
+
61
+
62
def system_id_to_label(sid):
    """ASVspoof protocol: system_id "-" marks bonafide (0); any attack id is spoof (1)."""
    is_bonafide = str(sid).strip() == "-"
    return int(not is_bonafide)
64
+
65
+
66
+ # =========================
67
+ # 2) 设备 & 模型
68
+ # =========================
69
+ torch.set_num_threads(CPU_THREADS)
70
+
71
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
72
+ print("Device:", device)
73
+ if device.type == "cuda":
74
+ print("GPU:", torch.cuda.get_device_name(0))
75
+ torch.backends.cudnn.benchmark = True
76
+
77
+ use_amp = device.type == "cuda"
78
+
79
+ feature_extractor = Wav2Vec2FeatureExtractor.from_pretrained(MODEL_DIR)
80
+ model = Wav2Vec2ForSequenceClassification.from_pretrained(MODEL_DIR).to(device).eval()
81
+
82
+ target_sr = feature_extractor.sampling_rate # 16000
83
+
84
+
85
# =========================
# 3) Read the parquet
# =========================
pf = pq.ParquetFile(PARQUET_FILE)
num_rows = pf.metadata.num_rows
# Ceiling division: number of record batches iter_batches will yield.
num_batches = (num_rows + BATCH_SIZE - 1) // BATCH_SIZE

print(f"Parquet: {PARQUET_FILE}")
print(f"Rows: {num_rows}, Batches: {num_batches}")
94
+
95
+
96
# =========================
# 4) Inference
# =========================
# Confusion-matrix counters (positive class = spoof = 1).
tp = fp = tn = fn = 0
correct = total = 0
# Counters for the optional key-vs-system_id label sanity check.
mismatch = checked = 0

t0 = time.time()
with torch.no_grad():
    pbar = tqdm(total=num_batches, desc=f"Predicting [{SPLIT}]", unit="batch")

    for rb in pf.iter_batches(batch_size=BATCH_SIZE, columns=["audio", "key", "system_id"]):
        # "audio" is a struct column with "bytes" (embedded) and "path" fields.
        audio_struct = rb.column(rb.schema.get_field_index("audio"))
        key_arr = rb.column(rb.schema.get_field_index("key"))
        sys_arr = rb.column(rb.schema.get_field_index("system_id"))

        bytes_arr = audio_struct.field("bytes")
        path_arr = audio_struct.field("path")

        waves, labels = [], []

        for b, p, k, sid in zip(
            bytes_arr.to_pylist(),
            path_arr.to_pylist(),
            key_arr.to_pylist(),
            sys_arr.to_pylist(),
        ):
            y = key_to_label(k)
            labels.append(y)

            # Cross-check the two label sources when enabled.
            if CHECK_LABEL_CONSISTENCY:
                checked += 1
                if y != system_id_to_label(sid):
                    mismatch += 1

            wav, sr = decode_audio(b, p)
            wav = resample(wav, sr, target_sr)
            waves.append(wav)

        inputs = feature_extractor(
            waves,
            sampling_rate=target_sr,
            padding=True,
            return_tensors="pt",
        )
        inputs = {k: v.to(device) for k, v in inputs.items()}
        labels_t = torch.tensor(labels, device=device)

        # Autocast only when running on CUDA (see use_amp above).
        if use_amp:
            with torch.amp.autocast("cuda"):
                logits = model(**inputs).logits
        else:
            logits = model(**inputs).logits

        preds = logits.argmax(dim=-1)

        total += labels_t.numel()
        correct += (preds == labels_t).sum().item()

        tp += ((preds == 1) & (labels_t == 1)).sum().item()
        fp += ((preds == 1) & (labels_t == 0)).sum().item()
        tn += ((preds == 0) & (labels_t == 0)).sum().item()
        fn += ((preds == 0) & (labels_t == 1)).sum().item()

        pbar.update(1)

    pbar.close()

elapsed = time.time() - t0
165
+
166
+
167
# =========================
# 5) Metrics
# =========================
# eps guards the divisions against degenerate (all-zero) confusion counts.
eps = 1e-12
acc = correct / max(total, 1)
precision = tp / (tp + fp + eps)
recall = tp / (tp + fn + eps)
f1 = 2 * precision * recall / (precision + recall + eps)
fnr = fn / (fn + tp + eps)  # miss rate on spoof
fpr = fp / (fp + tn + eps)  # false alarm rate on bonafide

print("\n===== Summary =====")
print(f"Accuracy : {acc:.6f} ({correct}/{total})")
print(f"TP={tp}, FP={fp}, TN={tn}, FN={fn}")
print(f"Time : {elapsed:.2f}s, {total/elapsed:.2f} samples/s")

if CHECK_LABEL_CONSISTENCY:
    print(f"Label check: key vs system_id mismatches = {mismatch}/{checked}")

print("\n===== Metrics (pos=spoof=1) =====")
print(f"Precision : {precision:.6f}")
print(f"Recall : {recall:.6f}")
print(f"FNR : {fnr:.6f}")
print(f"FPR : {fpr:.6f}")
print(f"F1-score : {f1:.6f}")
192
+
193
+ '''
194
+ ===== Summary =====
195
+ Accuracy : 0.896753 (63882/71237)
196
+ TP=63882, FP=7355, TN=0, FN=0
197
+ Time : 4266.32s, 16.70 samples/s
198
+ Label check: key vs system_id mismatches = 0/71237
199
+
200
+ ===== Metrics (pos=spoof=1) =====
201
+ Precision : 0.896753
202
+ Recall : 1.000000
203
+ FNR : 0.000000
204
+ FPR : 1.000000
205
+ F1-score : 0.945567
206
+
207
+ 进程已结束,退出代码为 0
208
+
209
+ '''