Spaces:
Runtime error
Runtime error
| # Purpose: ArtifactNet 7ch inference pipeline — HF Spaces (CPU, ONNX Runtime) | |
| # Dependencies: onnxruntime, torch (HPSS/Mel only), huggingface_hub, scipy | |
"""ArtifactNet v9.4 inference — onnxruntime CPU.

The UNet and CNN run as .onnx models (public-safe); HPSS, Mel and the 7ch
features are computed with PyTorch on CPU (fixed operations without weights,
so there is no exposure risk).
"""
| import os | |
| from pathlib import Path | |
| import numpy as np | |
| import onnxruntime as ort | |
| import torch | |
| from huggingface_hub import hf_hub_download | |
| from scipy import stats as sp_stats | |
| from config import ( | |
| HF_MODEL_REPO, UNET_ONNX_FILENAME, CNN_ONNX_FILENAME, | |
| SR, N_FFT, HOP_LENGTH, CHUNK_SAMPLES, BATCH_SIZE, | |
| ) | |
| from .audio_utils import sliding_chunks | |
| from .model import ( | |
| DifferentiableMel, hpss_gpu_pure, compute_forensic_features_7ch, | |
| ) | |
# Number of mel bins; passed to DifferentiableMel and used to size the
# compressed mel profile in the router features.
N_MELS = 128
# Frequency bands as (name, low_hz, high_hz). Each band is mapped to STFT bin
# indices by finding the first bin whose centre frequency is >= the given Hz
# value, so the upper edge is exclusive.
FREQ_BANDS = [
    ("sub", 0, 250),
    ("low", 250, 2000),
    ("mid", 2000, 6000),
    ("hi_mid", 6000, 10000),
    ("hi", 10000, 16000),
    ("air", 16000, 22050),
]
# ============================================================
# Lazy singletons
# ============================================================
# All four start as None and are populated by load_models(); nothing is
# downloaded or constructed at import time.
# ONNX Runtime session for the UNet model (produces the "mask" output).
_unet_sess: ort.InferenceSession | None = None
# ONNX Runtime session for the CNN model (produces the "logit" output).
_cnn_sess: ort.InferenceSession | None = None
# Mel front-end applied to the residual / harmonic / percussive magnitudes.
_mel: DifferentiableMel | None = None
# Hann window of length N_FFT used for torch.stft.
_stft_window: torch.Tensor | None = None
def _ort_threads() -> int:
    """Return the ONNX Runtime intra-op thread count.

    The value comes from the ``ORT_THREADS`` environment variable. When the
    variable is unset or cannot be parsed as an integer, 2 is returned (the
    original note states that the HF Spaces "CPU basic" tier has 2 vCPU).
    """
    fallback = 2
    raw_value = os.environ.get("ORT_THREADS")
    if raw_value is None:
        return fallback
    try:
        threads = int(raw_value)
    except ValueError:
        # Non-numeric override: quietly use the default instead of failing.
        return fallback
    return threads
def _resolve_onnx(filename: str, env_var: str) -> str:
    """Return the path of an ONNX model file.

    If the environment variable ``env_var`` (ARTIFACTNET_UNET_ONNX /
    ARTIFACTNET_CNN_ONNX at the call sites) names an existing file, that local
    path is returned as-is. Otherwise ``filename`` is fetched from the
    ``HF_MODEL_REPO`` repository via ``hf_hub_download``.
    """
    override = os.environ.get(env_var)
    # An unset/empty variable or a path that is not a file both fall through
    # to the Hub download.
    if not override or not Path(override).is_file():
        return hf_hub_download(HF_MODEL_REPO, filename)
    return override
def load_models():
    """Create the ONNX sessions and the Mel/STFT-window helpers (once).

    Populates the module-level singletons ``_unet_sess``, ``_cnn_sess``,
    ``_mel`` and ``_stft_window``. Calling it again after a successful
    initialisation is a no-op.

    Everything is built into local variables first and only published to the
    globals at the very end. If any step raises (download, session creation,
    mel construction), no global is modified, so the next call retries from
    scratch instead of returning early with a half-initialised module.

    Raises:
        Whatever ``_resolve_onnx``, ``ort.InferenceSession`` or
        ``DifferentiableMel`` raise; the exception is propagated unchanged.
    """
    global _unet_sess, _cnn_sess, _mel, _stft_window

    already_loaded = all(
        obj is not None
        for obj in (_unet_sess, _cnn_sess, _mel, _stft_window)
    )
    if already_loaded:
        return

    unet_path = _resolve_onnx(UNET_ONNX_FILENAME, "ARTIFACTNET_UNET_ONNX")
    cnn_path = _resolve_onnx(CNN_ONNX_FILENAME, "ARTIFACTNET_CNN_ONNX")

    # Read the thread setting once so the session option and the log line
    # cannot disagree.
    threads = _ort_threads()
    opts = ort.SessionOptions()
    opts.intra_op_num_threads = threads
    opts.inter_op_num_threads = 1
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    providers = ["CPUExecutionProvider"]
    unet_sess = ort.InferenceSession(
        unet_path, sess_options=opts, providers=providers)
    cnn_sess = ort.InferenceSession(
        cnn_path, sess_options=opts, providers=providers)

    mel = DifferentiableMel(sr=SR, n_fft=N_FFT, n_mels=N_MELS)
    mel.eval()
    window = torch.hann_window(N_FFT)

    # Publish together. ``_unet_sess`` is assigned last because other code in
    # this module uses ``_unet_sess is None`` as the "not loaded yet" check.
    _cnn_sess = cnn_sess
    _mel = mel
    _stft_window = window
    _unet_sess = unet_sess
    print(f"[hf-spaces] ONNX sessions ready (intra_threads={threads})", flush=True)
| # ============================================================ | |
| # Feature extraction helpers (75-dim Router + 28-dim Verdict) | |
| # ============================================================ | |
def _extract_router_verdict_features(
    all_mag, all_res, all_H, all_P, all_mask, all_mel_res, probs,
):
    """Build the 75-dim router vector and the 28-dim verdict vector.

    The original note states this mirrors ``extract_features()`` in infer.py
    (device=CPU); all arithmetic is kept in the same order for that reason.

    Args:
        all_mag, all_res, all_H, all_P: 4-D magnitude tensors with frequency
            on dim 2 (indexed as ``[:, :, f0:f1, :]``): original STFT
            magnitude, masked residual, and the harmonic / percussive parts
            of the residual. Presumably (chunks, 1, freq, time) — confirm
            against the caller.
        all_mask: tensor of mask values; only its flattened distribution is
            used.
        all_mel_res: mel tensor reduced over dims 0 and 3; after squeezing it
            is expected to leave at least ``32 * (N_MELS // 32)`` values.
        probs: per-chunk probabilities. Any array-like is accepted and is
            flattened to 1-D, so both ``(N,)`` and ``(N, 1)`` work.

    Returns:
        ``(router_feat, verdict_feat)`` as float32 arrays; NaN/inf values are
        replaced via ``np.nan_to_num``. Layout of ``router_feat``:
        3 values per band (orig share, residual share, ratio), 32 compressed
        mel-profile values, H/(H+P) ratio, 2 values per band (H share,
        P share), 7 mask statistics, 5 probability statistics.
        ``verdict_feat`` is 20 probability statistics followed by 8
        residual/HPSS values.
    """
    n_bins = all_mag.shape[2]
    freq_hz = torch.linspace(0, SR / 2, n_bins)

    def first_bin_at_or_above(hz, default):
        # Index of the first STFT bin whose frequency is >= hz.
        hits = (freq_hz >= hz).nonzero(as_tuple=True)[0]
        return hits[0].item() if len(hits) else default

    def band_power(spec, i0, i1=None):
        # Mean squared magnitude over the frequency slice [i0:i1).
        return spec[:, :, i0:i1, :].pow(2).mean().item()

    orig_total = all_mag.pow(2).mean().item() + 1e-8
    res_total = all_res.pow(2).mean().item() + 1e-8

    band_idx = [
        (first_bin_at_or_above(f_lo, 0), first_bin_at_or_above(f_hi, n_bins))
        for _, f_lo, f_hi in FREQ_BANDS
    ]

    # --- Router features -------------------------------------------------
    rf = []
    for i0, i1 in band_idx:
        orig_share = band_power(all_mag, i0, i1) / orig_total
        res_share = band_power(all_res, i0, i1) / res_total
        rf.extend([orig_share, res_share, res_share / (orig_share + 1e-8)])

    # detach() is a no-op for tensors without grad and keeps .numpy() from
    # failing if the mel module ever returns grad-tracking tensors.
    mel_profile = all_mel_res.mean(dim=[0, 3]).squeeze().detach().cpu().numpy()
    step = N_MELS // 32
    compressed = mel_profile[:32 * step].reshape(32, step).mean(axis=1)
    compressed = compressed - compressed.mean()
    norm = np.abs(compressed).max() + 1e-8
    rf.extend((compressed / norm).tolist())

    H_total = all_H.pow(2).mean().item() + 1e-8
    P_total = all_P.pow(2).mean().item() + 1e-8
    hp_ratio = H_total / (H_total + P_total)
    rf.append(hp_ratio)
    for i0, i1 in band_idx:
        rf.extend([
            band_power(all_H, i0, i1) / H_total,
            band_power(all_P, i0, i1) / P_total,
        ])

    # reshape(-1) yields the same values as flatten() but avoids copying the
    # (potentially very large) mask when it is already contiguous.
    mask_np = all_mask.detach().cpu().numpy().reshape(-1)
    rf.extend([
        float(mask_np.mean()), float(mask_np.std()),
        float(np.percentile(mask_np, 10)), float(np.percentile(mask_np, 25)),
        float(np.percentile(mask_np, 75)), float(np.percentile(mask_np, 90)),
        float(np.median(mask_np)),
    ])

    # Flatten so an (N, 1) input behaves exactly like (N,). Without this,
    # np.diff below would run along the size-1 axis and produce an empty
    # array (NaN std, ValueError on max).
    probs_flat = np.asarray(probs).reshape(-1)
    rf.extend([
        float(probs_flat.mean()), float(probs_flat.std()),
        float(np.median(probs_flat)),
        float(np.percentile(probs_flat, 10)),
        float(np.percentile(probs_flat, 90)),
    ])
    router_feat = np.nan_to_num(np.array(rf, dtype=np.float32))

    # --- Verdict features: 20 probability statistics ---------------------
    arr = probs_flat.astype(np.float64)
    n = len(arr)
    cnn_20 = np.array([
        n, arr.mean(), arr.std(), np.median(arr),
        arr.min(), arr.max(), arr.max() - arr.min(),
        np.percentile(arr, 10), np.percentile(arr, 25),
        np.percentile(arr, 75), np.percentile(arr, 90),
        (arr >= 0.3).mean(), (arr >= 0.5).mean(),
        (arr >= 0.7).mean(), (arr >= 0.8).mean(), (arr >= 0.9).mean(),
        # Higher moments and chunk-to-chunk differences need enough samples.
        float(sp_stats.skew(arr)) if n >= 3 else 0.0,
        float(sp_stats.kurtosis(arr)) if n >= 3 else 0.0,
        float(np.diff(arr).std()) if n >= 2 else 0.0,
        float(np.abs(np.diff(arr)).max()) if n >= 2 else 0.0,
    ], dtype=np.float32)

    # --- Verdict features: 8 residual / HPSS values ----------------------
    hf8k_i = first_bin_at_or_above(8000, n_bins)
    # Index 5 is the last FREQ_BANDS entry ("air").
    ai0, ai1 = band_idx[5]
    res_8 = np.array([
        band_power(all_res, hf8k_i) / res_total,
        band_power(all_res, ai0, ai1) / res_total,
        band_power(all_H, ai0, ai1) / H_total,
        band_power(all_P, ai0, ai1) / P_total,
        float(mel_profile[-1]),
        float(mel_profile[0]),
        float(mask_np.mean()),
        float(hp_ratio),
    ], dtype=np.float32)

    verdict_feat = np.nan_to_num(np.concatenate([cnn_20, res_8]))
    return router_feat, verdict_feat
| # ============================================================ | |
| # Inference | |
| # ============================================================ | |
def _forensic_channel_stats(feature_batches):
    """Summarise the 7-channel feature tensors into a dict of values in [0, 1].

    Each channel is averaged over dims 2 and 3 per chunk; the per-channel
    median across chunks is then min-max normalised with that channel's own
    min/max across chunks and clamped to [0, 1].
    """
    stacked = torch.cat(feature_batches, dim=0)
    channel_means = stacked.mean(dim=[2, 3])
    medians = channel_means.median(dim=0).values
    lowest = channel_means.min(dim=0).values
    highest = channel_means.max(dim=0).values
    span = highest - lowest + 1e-8
    normalized = ((medians - lowest) / span).clamp(0, 1)
    keys = (
        "residual_energy",
        "harmonic_strength",
        "percussive_strength",
        "temporal_delta",
        "temporal_accel",
        "hp_ratio",
        "spectral_flux",
    )
    # Explicit indexing: fewer than 7 channels still raises IndexError.
    return {key: float(normalized[idx]) for idx, key in enumerate(keys)}


def run_e2e_inference(wav_mono_tensor: torch.Tensor):
    """Run the full pipeline on a mono waveform (ONNX Runtime CPU + PyTorch).

    Per batch of chunks: STFT -> UNet mask (ONNX) -> residual magnitude ->
    HPSS -> Mel of residual/H/P -> 7-channel features -> CNN logit (ONNX) ->
    sigmoid probability.

    Args:
        wav_mono_tensor: mono waveform tensor; it is split into chunks of
            ``CHUNK_SAMPLES`` samples by ``sliding_chunks``.

    Returns:
        A 6-tuple ``(probs, placeholder, metadata, forensic_stats,
        router_feat, verdict_feat)``:
        * ``probs``: list with one probability (float) per chunk.
        * ``placeholder``: zeros shaped like the input waveform.
        * ``metadata``: the per-chunk metadata returned by ``sliding_chunks``.
        * ``forensic_stats``: dict of 7 normalised channel statistics
          (empty dict if nothing was processed).
        * ``router_feat`` / ``verdict_feat``: float32 arrays of length 75 / 28
          (all zeros if nothing was processed).
    """
    if _unet_sess is None:
        load_models()

    chunk_data = sliding_chunks(wav_mono_tensor, CHUNK_SAMPLES)
    if not chunk_data:
        return [], torch.zeros_like(wav_mono_tensor), [], {}, \
            np.zeros(75, dtype=np.float32), np.zeros(28, dtype=np.float32)

    chunks = [chunk for chunk, _ in chunk_data]
    metadata_list = [meta for _, meta in chunk_data]

    probs = []
    all_features = []
    all_mag_list, all_res_list, all_H_list, all_P_list = [], [], [], []
    all_mask_list, all_mel_res_list = [], []

    # Pure inference: disabling autograd avoids building graphs for every
    # intermediate tensor and guarantees .numpy() works even if the mel module
    # holds parameters that require grad.
    with torch.no_grad():
        for start in range(0, len(chunks), BATCH_SIZE):
            batch = torch.stack(chunks[start:start + BATCH_SIZE])  # (B, CHUNK_SAMPLES)

            # STFT (torch, CPU)
            stft = torch.stft(
                batch, N_FFT, HOP_LENGTH,
                window=_stft_window, return_complex=True)
            stft_mag = stft.abs().unsqueeze(1)  # (B, 1, F, T)

            # UNet mask via ONNX. ascontiguousarray only copies when the
            # array is not already C-contiguous float32.
            mask_np = _unet_sess.run(
                ["mask"],
                {"stft_mag": np.ascontiguousarray(
                    stft_mag.numpy(), dtype=np.float32)},
            )[0]
            mask = torch.from_numpy(mask_np)
            res_mag = mask * stft_mag

            # HPSS — per the original note: the CPU median filter
            # (unfold + median) is used to stay on the training distribution;
            # librosa.decompose.hpss gives different results and makes the
            # v9.4 CNN misjudge (see the warning in CLAUDE.md).
            H_mag, P_mag = hpss_gpu_pure(res_mag)

            # Mel of residual / harmonic / percussive
            mel_res = _mel(res_mag)
            mel_H = _mel(H_mag)
            mel_P = _mel(P_mag)

            features_7ch = compute_forensic_features_7ch(mel_res, mel_H, mel_P)
            all_features.append(features_7ch)

            # CNN logit via ONNX -> sigmoid (logits clipped to avoid overflow)
            logits = _cnn_sess.run(
                ["logit"],
                {"features_7ch": np.ascontiguousarray(
                    features_7ch.numpy(), dtype=np.float32)},
            )[0]
            batch_probs = 1.0 / (1.0 + np.exp(-np.clip(logits, -30, 30)))
            # Flatten so a (B, 1) logit output yields one float per chunk,
            # exactly like a (B,) output does.
            probs.extend(batch_probs.reshape(-1).tolist())

            all_mag_list.append(stft_mag)
            all_res_list.append(res_mag)
            all_H_list.append(H_mag)
            all_P_list.append(P_mag)
            all_mask_list.append(mask)
            all_mel_res_list.append(mel_res)

        # The guards below only matter if the loop ran zero times (e.g. a
        # non-positive BATCH_SIZE); normally every list is non-empty here.
        forensic_stats = (
            _forensic_channel_stats(all_features) if all_features else {}
        )

        probs_arr = np.array(probs, dtype=np.float32)
        if all_mag_list:
            router_feat, verdict_feat = _extract_router_verdict_features(
                torch.cat(all_mag_list, dim=0),
                torch.cat(all_res_list, dim=0),
                torch.cat(all_H_list, dim=0),
                torch.cat(all_P_list, dim=0),
                torch.cat(all_mask_list, dim=0),
                torch.cat(all_mel_res_list, dim=0),
                probs_arr,
            )
        else:
            router_feat = np.zeros(75, dtype=np.float32)
            verdict_feat = np.zeros(28, dtype=np.float32)

    residual_placeholder = torch.zeros_like(wav_mono_tensor)
    return probs, residual_placeholder, metadata_list, forensic_stats, router_feat, verdict_feat