AnodHuang commited on
Commit
8a5bd90
·
verified ·
1 Parent(s): 1603129

Upload verify_matty.py

Browse files
Files changed (1) hide show
  1. verify_matty.py +214 -0
verify_matty.py ADDED
@@ -0,0 +1,214 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import io
3
+ import time
4
+ import numpy as np
5
+ import torch
6
+ import soundfile as sf
7
+ from tqdm import tqdm
8
+ import pyarrow.parquet as pq
9
+
10
+ from transformers import AutoFeatureExtractor, ASTForAudioClassification
11
+
12
+ # =========================
13
+ # 0) 你只改这里
14
+ # =========================
15
+ PARQUET_DIR = r"D:\capstone\asv_spoof\parquet"
16
+ MODEL_DIR = r"D:\capstone\models\matty_snr"
17
+
18
+ SPLIT = "test" # "train" / "validation" / "test"
19
+ BATCH_SIZE = 32 # 4090 可 64
20
+ CPU_THREADS = 8 # 影响音频解码/预处理
21
+
22
+ # key 的定义:根据你的数据分布 & system_id 对齐: key=1 是 spoof,key=0 是 bonafide
23
+ # (system_id: '-' 是 bonafide;'Axx' 是 spoof)
24
+ KEY_SPOOF_VALUE = 1
25
+
26
+ PARQUET_FILE = os.path.join(PARQUET_DIR, f"{SPLIT}-00000-of-00001.parquet")
27
+
28
+ # 是否做 system_id 与 key 的一致性检查(不影响推理,只打印检查结果)
29
+ CHECK_LABEL_CONSISTENCY = True
30
+
31
+
32
+ # =========================
33
+ # 1) 音频解码/重采样(不落盘)
34
+ # =========================
35
+ def decode_audio(bytes_blob: bytes | None, path_str: str | None):
36
+ if bytes_blob is not None:
37
+ wav, sr = sf.read(io.BytesIO(bytes_blob), dtype="float32", always_2d=False)
38
+ else:
39
+ if not path_str or not os.path.exists(path_str):
40
+ raise RuntimeError("audio.bytes 为空,且 audio.path 不存在/不可用")
41
+ wav, sr = sf.read(path_str, dtype="float32", always_2d=False)
42
+
43
+ if isinstance(wav, np.ndarray) and wav.ndim > 1:
44
+ wav = wav.mean(axis=1)
45
+ return wav.astype(np.float32), int(sr)
46
+
47
+
48
+ def simple_resample(wav: np.ndarray, sr: int, new_sr: int) -> np.ndarray:
49
+ if sr == new_sr:
50
+ return wav
51
+ if wav.size == 0:
52
+ return wav
53
+ x_old = np.linspace(0, 1, num=wav.shape[0], endpoint=False)
54
+ new_len = int(round(wav.shape[0] * (new_sr / sr)))
55
+ x_new = np.linspace(0, 1, num=new_len, endpoint=False)
56
+ return np.interp(x_new, x_old, wav).astype(np.float32)
57
+
58
+
59
+ def key_to_label01(k) -> int:
60
+ # parquet 里 key 是 int64,但有时 to_pylist 可能给 int 或 str
61
+ v = int(k)
62
+ return 1 if v == KEY_SPOOF_VALUE else 0
63
+
64
+
65
+ def system_id_to_label01(sid: str) -> int:
66
+ sid = str(sid).strip()
67
+ return 0 if sid == "-" else 1 # '-' bonafide, 'Axx' spoof
68
+
69
+
70
+ # =========================
71
+ # 2) 设备 & 模型
72
+ # =========================
73
+ torch.set_num_threads(CPU_THREADS)
74
+
75
+ device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
76
+ print("Device:", device)
77
+ if device.type == "cuda":
78
+ print("GPU:", torch.cuda.get_device_name(0))
79
+ torch.backends.cudnn.benchmark = True
80
+
81
+ use_amp = (device.type == "cuda")
82
+
83
+ extractor = AutoFeatureExtractor.from_pretrained(MODEL_DIR)
84
+ model = ASTForAudioClassification.from_pretrained(MODEL_DIR).to(device).eval()
85
+ target_sr = getattr(extractor, "sampling_rate", 16000)
86
+
87
+ # =========================
88
+ # 3) 读 parquet
89
+ # =========================
90
+ pf = pq.ParquetFile(PARQUET_FILE)
91
+ num_rows = pf.metadata.num_rows
92
+ num_batches = (num_rows + BATCH_SIZE - 1) // BATCH_SIZE
93
+
94
+ print(f"Parquet: {PARQUET_FILE}")
95
+ print(f"Rows: {num_rows}, Batches: {num_batches}, BatchSize: {BATCH_SIZE}")
96
+
97
+ # =========================
98
+ # 4) 推理 + 指标统计
99
+ # =========================
100
+ correct = 0
101
+ total = 0
102
+ tp = fp = tn = fn = 0 # pos=spoof=1
103
+
104
+ # 可选:检查 key 与 system_id 是否一致
105
+ mismatch = 0
106
+ checked = 0
107
+
108
+ t0 = time.time()
109
+ with torch.no_grad():
110
+ pbar = tqdm(total=num_batches, desc=f"Predicting [{SPLIT}]", unit="batch")
111
+
112
+ for rb in pf.iter_batches(batch_size=BATCH_SIZE, columns=["audio", "key", "system_id"]):
113
+ audio_struct = rb.column(rb.schema.get_field_index("audio"))
114
+ key_arr = rb.column(rb.schema.get_field_index("key"))
115
+ sys_arr = rb.column(rb.schema.get_field_index("system_id"))
116
+
117
+ bytes_arr = audio_struct.field("bytes") if audio_struct.type.get_field_index("bytes") != -1 else None
118
+ path_arr = audio_struct.field("path") if audio_struct.type.get_field_index("path") != -1 else None
119
+
120
+ keys = key_arr.to_pylist()
121
+ sysids = sys_arr.to_pylist()
122
+ bytes_list = bytes_arr.to_pylist() if bytes_arr is not None else [None] * len(keys)
123
+ path_list = path_arr.to_pylist() if path_arr is not None else [None] * len(keys)
124
+
125
+ waves = []
126
+ labels = []
127
+
128
+ for b, p, k, sid in zip(bytes_list, path_list, keys, sysids):
129
+ y = key_to_label01(k)
130
+ labels.append(y)
131
+
132
+ if CHECK_LABEL_CONSISTENCY:
133
+ y2 = system_id_to_label01(sid)
134
+ checked += 1
135
+ if y != y2:
136
+ mismatch += 1
137
+
138
+ wav, sr = decode_audio(b, p)
139
+ wav = simple_resample(wav, sr, target_sr)
140
+ waves.append(wav)
141
+
142
+ inputs = extractor(
143
+ waves,
144
+ sampling_rate=target_sr,
145
+ return_tensors="pt",
146
+ padding=True,
147
+ )
148
+ inputs = {k: v.to(device, non_blocking=True) for k, v in inputs.items()}
149
+ labels_t = torch.tensor(labels, dtype=torch.long, device=device)
150
+
151
+ if use_amp:
152
+ with torch.amp.autocast("cuda"):
153
+ logits = model(**inputs).logits
154
+ else:
155
+ logits = model(**inputs).logits
156
+
157
+ preds = torch.argmax(logits, dim=-1)
158
+
159
+ total += labels_t.numel()
160
+ correct += (preds == labels_t).sum().item()
161
+
162
+ tp += ((preds == 1) & (labels_t == 1)).sum().item()
163
+ fp += ((preds == 1) & (labels_t == 0)).sum().item()
164
+ tn += ((preds == 0) & (labels_t == 0)).sum().item()
165
+ fn += ((preds == 0) & (labels_t == 1)).sum().item()
166
+
167
+ pbar.update(1)
168
+
169
+ pbar.close()
170
+
171
+ elapsed = time.time() - t0
172
+
173
+ # =========================
174
+ # 5) 计算指标
175
+ # =========================
176
+ acc = correct / max(total, 1)
177
+
178
+ eps = 1e-12
179
+ precision = tp / (tp + fp + eps)
180
+ recall = tp / (tp + fn + eps) # TPR
181
+ f1 = 2 * precision * recall / (precision + recall + eps)
182
+ fnr = fn / (fn + tp + eps)
183
+ fpr = fp / (fp + tn + eps)
184
+
185
+ print("\n===== Summary =====")
186
+ print(f"Split : {SPLIT}")
187
+ print(f"Accuracy : {acc:.6f} ({correct}/{total})")
188
+ print(f"Confusion : TP={tp}, FP={fp}, TN={tn}, FN={fn}")
189
+ print(f"Time : {elapsed:.2f}s, {total / max(elapsed,1e-9):.2f} samples/s")
190
+
191
+ if CHECK_LABEL_CONSISTENCY:
192
+ print(f"Label check: key vs system_id mismatches = {mismatch}/{checked}")
193
+
194
+ print("\n===== Metrics (pos=spoof=1) =====")
195
+ print(f"Precision : {precision:.6f}")
196
+ print(f"FNR : {fnr:.6f}")
197
+ print(f"FPR : {fpr:.6f}")
198
+ print(f"F1-score : {f1:.6f}")
199
+
200
+
201
+ '''
202
+ ===== Summary =====
203
+ Split : test
204
+ Accuracy : 0.898845 (64031/71237)
205
+ Confusion : TP=57091, FP=415, TN=6940, FN=6791
206
+ Time : 1135.30s, 62.75 samples/s
207
+ Label check: key vs system_id mismatches = 0/71237
208
+
209
+ ===== Metrics (pos=spoof=1) =====
210
+ Precision : 0.992783
211
+ FNR : 0.106305
212
+ FPR : 0.056424
213
+ F1-score : 0.940637
214
+ '''