# Purpose: ArtifactNet 7ch inference pipeline — HF Spaces (CPU, ONNX Runtime)
# Dependencies: onnxruntime, torch (HPSS/Mel only), huggingface_hub, scipy
"""ArtifactNet v9.4 inference — ONNX Runtime on CPU.

The UNet and the CNN run from ``.onnx`` exports (public-safe). HPSS, Mel and the
7-channel forensic features are computed with PyTorch on CPU: they are fixed,
weight-free operations, so running them in the clear exposes nothing.
"""
import os
from pathlib import Path

import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from scipy import stats as sp_stats

from config import (
    HF_MODEL_REPO,
    UNET_ONNX_FILENAME,
    CNN_ONNX_FILENAME,
    SR,
    N_FFT,
    HOP_LENGTH,
    CHUNK_SAMPLES,
    BATCH_SIZE,
)
from .audio_utils import sliding_chunks
from .model import (
    DifferentiableMel,
    hpss_gpu_pure,
    compute_forensic_features_7ch,
)

N_MELS = 128

# (name, low_hz, high_hz) bands used by the router / verdict features.
# Order matters: the verdict features read the "air" band as band_idx[5].
FREQ_BANDS = [
    ("sub", 0, 250),
    ("low", 250, 2000),
    ("mid", 2000, 6000),
    ("hi_mid", 6000, 10000),
    ("hi", 10000, 16000),
    ("air", 16000, 22050),
]

# Lengths of the router / verdict vectors produced by
# _extract_router_verdict_features (also used for the zero fallbacks).
_ROUTER_DIM = 75
_VERDICT_DIM = 28

# Output keys of the forensic summary, in feature-channel order (0..6).
_FORENSIC_KEYS = (
    "residual_energy",
    "harmonic_strength",
    "percussive_strength",
    "temporal_delta",
    "temporal_accel",
    "hp_ratio",
    "spectral_flux",
)

_DEFAULT_ORT_THREADS = 2  # HF Spaces "CPU basic" = 2 vCPU


# ============================================================
# Lazy singletons
# ============================================================
_unet_sess: ort.InferenceSession | None = None
_cnn_sess: ort.InferenceSession | None = None
_mel: DifferentiableMel | None = None
_stft_window: torch.Tensor | None = None


def _ort_threads() -> int:
    """Return the ONNX Runtime intra-op thread count.

    Defaults to 2 (HF Spaces CPU basic has 2 vCPUs). The ``ORT_THREADS``
    environment variable overrides it. Non-integer or negative values fall
    back to the default; 0 is passed through and lets ONNX Runtime choose.
    """
    try:
        n_threads = int(os.environ.get("ORT_THREADS", str(_DEFAULT_ORT_THREADS)))
    except ValueError:
        return _DEFAULT_ORT_THREADS
    return n_threads if n_threads >= 0 else _DEFAULT_ORT_THREADS


def _resolve_onnx(filename: str, env_var: str) -> str:
    """Return a local filesystem path for an ONNX model.

    If ``env_var`` (ARTIFACTNET_UNET_ONNX / ARTIFACTNET_CNN_ONNX) names an
    existing file, that local override is used. Otherwise ``filename`` is
    fetched from ``HF_MODEL_REPO`` via ``hf_hub_download``.
    """
    local = os.environ.get(env_var)
    if local:
        if Path(local).is_file():
            return local
        # A set-but-wrong override is almost always a mistake; don't hide it.
        print(
            f"[hf-spaces] {env_var}={local!r} is not a file; "
            f"falling back to HF Hub ({filename})",
            flush=True,
        )
    return hf_hub_download(HF_MODEL_REPO, filename)


def _models_ready() -> bool:
    """Return True once every lazy singleton has been initialised."""
    return all(
        obj is not None for obj in (_unet_sess, _cnn_sess, _mel, _stft_window)
    )


def load_models():
    """Create the ONNX sessions and the Mel / STFT-window helpers (idempotent).

    Everything is built into locals first and published to the module globals
    only after every step has succeeded. A failure part-way through (e.g. the
    CNN download) therefore leaves the module fully uninitialised, and the
    next call retries instead of treating a half-loaded state as ready.
    """
    global _unet_sess, _cnn_sess, _mel, _stft_window
    if _models_ready():
        return

    unet_path = _resolve_onnx(UNET_ONNX_FILENAME, "ARTIFACTNET_UNET_ONNX")
    cnn_path = _resolve_onnx(CNN_ONNX_FILENAME, "ARTIFACTNET_CNN_ONNX")

    n_threads = _ort_threads()
    opts = ort.SessionOptions()
    opts.intra_op_num_threads = n_threads
    opts.inter_op_num_threads = 1
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL
    providers = ["CPUExecutionProvider"]

    unet_sess = ort.InferenceSession(unet_path, sess_options=opts, providers=providers)
    cnn_sess = ort.InferenceSession(cnn_path, sess_options=opts, providers=providers)

    mel = DifferentiableMel(sr=SR, n_fft=N_FFT, n_mels=N_MELS)
    mel.eval()
    window = torch.hann_window(N_FFT)

    # Publish all singletons together, only after everything above succeeded.
    _unet_sess, _cnn_sess, _mel, _stft_window = unet_sess, cnn_sess, mel, window
    print(f"[hf-spaces] ONNX sessions ready (intra_threads={n_threads})", flush=True)


# ============================================================
# Feature extraction helpers (75-dim Router + 28-dim Verdict)
# ============================================================
def _first_bin_at_or_above(freq_hz: torch.Tensor, hz: float, default: int) -> int:
    """Return the index of the first bin with ``freq_hz >= hz``, else ``default``."""
    hits = (freq_hz >= hz).nonzero(as_tuple=True)[0]
    return hits[0].item() if len(hits) else default


def _band_power_ratio(x: torch.Tensor, i0: int, i1: int, total: float) -> float:
    """Return the mean power of ``x`` in frequency bins [i0, i1) divided by ``total``.

    Frequency is axis 2 of ``x``.
    """
    return x[:, :, i0:i1, :].pow(2).mean().item() / total


def _extract_router_verdict_features(
    all_mag, all_res, all_H, all_P, all_mask, all_mel_res, probs,
):
    """Build the 75-dim router vector and the 28-dim verdict vector.

    This is the same logic as ``extract_features()`` in infer.py, run on CPU.
    The recipe must stay in sync with training, so keep the arithmetic as is.

    Args:
        all_mag: STFT magnitudes of every chunk, (N, 1, F, T) as built in
            ``run_e2e_inference``; frequency is axis 2.
        all_res: UNet residual magnitudes (mask * all_mag), indexed like all_mag.
        all_H, all_P: HPSS harmonic / percussive parts of ``all_res``.
        all_mask: UNet masks for every chunk.
        all_mel_res: Mel projection of ``all_res``. It is averaged over axes 0
            and 3 into a per-mel-bin profile (assumed N_MELS bins, TODO confirm).
        probs: 1-D float array of per-chunk CNN probabilities (non-empty).

    Returns:
        A ``(router_feat, verdict_feat)`` pair of float32 arrays of length 75
        and 28. NaN/inf values are replaced via ``np.nan_to_num``.
    """
    n_bins = all_mag.shape[2]
    freq_hz = torch.linspace(0, SR / 2, n_bins)
    orig_total = all_mag.pow(2).mean().item() + 1e-8
    res_total = all_res.pow(2).mean().item() + 1e-8

    band_idx = [
        (
            _first_bin_at_or_above(freq_hz, flo, 0),
            _first_bin_at_or_above(freq_hz, fhi, n_bins),
        )
        for _, flo, fhi in FREQ_BANDS
    ]

    # ---- Router features (75) ----
    # 18: per band -> (orig relative power, residual relative power, res/orig)
    rf = []
    for i0, i1 in band_idx:
        orig_e = _band_power_ratio(all_mag, i0, i1, orig_total)
        res_e = _band_power_ratio(all_res, i0, i1, res_total)
        rf.extend([orig_e, res_e, res_e / (orig_e + 1e-8)])

    # 32: residual Mel profile pooled to 32 bins, zero-mean, max-abs normalised
    mel_profile = all_mel_res.mean(dim=[0, 3]).squeeze().cpu().numpy()
    step = N_MELS // 32
    compressed = mel_profile[:32 * step].reshape(32, step).mean(axis=1)
    compressed = compressed - compressed.mean()
    norm = np.abs(compressed).max() + 1e-8
    rf.extend((compressed / norm).tolist())

    # 1 + 12: global harmonic ratio, then per band -> (H ratio, P ratio)
    H_total = all_H.pow(2).mean().item() + 1e-8
    P_total = all_P.pow(2).mean().item() + 1e-8
    hp_ratio = H_total / (H_total + P_total)
    rf.append(hp_ratio)
    for i0, i1 in band_idx:
        rf.extend([
            _band_power_ratio(all_H, i0, i1, H_total),
            _band_power_ratio(all_P, i0, i1, P_total),
        ])

    # 7: UNet mask distribution
    mask_np = all_mask.cpu().numpy().flatten()
    rf.extend([
        float(mask_np.mean()),
        float(mask_np.std()),
        float(np.percentile(mask_np, 10)),
        float(np.percentile(mask_np, 25)),
        float(np.percentile(mask_np, 75)),
        float(np.percentile(mask_np, 90)),
        float(np.median(mask_np)),
    ])

    # 5: CNN probability distribution
    rf.extend([
        float(probs.mean()),
        float(probs.std()),
        float(np.median(probs)),
        float(np.percentile(probs, 10)),
        float(np.percentile(probs, 90)),
    ])
    router_feat = np.nan_to_num(np.array(rf, dtype=np.float32))

    # ---- Verdict features (20 CNN-probability stats + 8 residual stats) ----
    arr = probs.astype(np.float64)
    n = len(arr)
    cnn_20 = np.array([
        n, arr.mean(), arr.std(), np.median(arr),
        arr.min(), arr.max(), arr.max() - arr.min(),
        np.percentile(arr, 10), np.percentile(arr, 25),
        np.percentile(arr, 75), np.percentile(arr, 90),
        (arr >= 0.3).mean(), (arr >= 0.5).mean(), (arr >= 0.7).mean(),
        (arr >= 0.8).mean(), (arr >= 0.9).mean(),
        float(sp_stats.skew(arr)) if n >= 3 else 0.0,
        float(sp_stats.kurtosis(arr)) if n >= 3 else 0.0,
        float(np.diff(arr).std()) if n >= 2 else 0.0,
        float(np.abs(np.diff(arr)).max()) if n >= 2 else 0.0,
    ], dtype=np.float32)

    hf8k_i = _first_bin_at_or_above(freq_hz, 8000, n_bins)
    ai0, ai1 = band_idx[5]  # "air" band (FREQ_BANDS[5])
    res_8 = np.array([
        all_res[:, :, hf8k_i:, :].pow(2).mean().item() / res_total,
        _band_power_ratio(all_res, ai0, ai1, res_total),
        _band_power_ratio(all_H, ai0, ai1, H_total),
        _band_power_ratio(all_P, ai0, ai1, P_total),
        float(mel_profile[-1]),
        float(mel_profile[0]),
        float(mask_np.mean()),
        float(hp_ratio),
    ], dtype=np.float32)

    verdict_feat = np.nan_to_num(np.concatenate([cnn_20, res_8]))
    return router_feat, verdict_feat


# ============================================================
# Inference
# ============================================================
def _sigmoid(logits: np.ndarray) -> np.ndarray:
    """Apply the logistic sigmoid, clipping logits to [-30, 30] to avoid overflow."""
    return 1.0 / (1.0 + np.exp(-np.clip(logits, -30, 30)))


def _summarize_forensic_features(all_features: list) -> dict:
    """Collapse per-chunk 7ch features into seven [0, 1] summary scores.

    Each feature channel is averaged over axes 2 and 3 per chunk. The score is
    then (median - min) / (max - min) across chunks, clamped to [0, 1].
    Assumes features are (B, 7, H, W) per batch (TODO confirm).
    Returns {} when there are no features.
    """
    if not all_features:
        return {}
    feats = torch.cat(all_features, dim=0)
    channel_means = feats.mean(dim=[2, 3])
    feature_medians = channel_means.median(dim=0).values
    feat_min = channel_means.min(dim=0).values
    feat_max = channel_means.max(dim=0).values
    feat_range = feat_max - feat_min + 1e-8
    normalized = ((feature_medians - feat_min) / feat_range).clamp(0, 1)
    return {key: float(normalized[i]) for i, key in enumerate(_FORENSIC_KEYS)}


@torch.no_grad()
def run_e2e_inference(wav_mono_tensor: torch.Tensor):
    """Run the full pipeline on a mono waveform (ONNX Runtime CPU + PyTorch HPSS/Mel).

    Returns:
        A 6-tuple ``(probs, residual_placeholder, metadata_list, forensic_stats,
        router_feat, verdict_feat)``:

        - probs: list of per-chunk CNN probabilities (floats).
        - residual_placeholder: ``torch.zeros_like(wav_mono_tensor)``; no
          residual waveform is reconstructed here.
        - metadata_list: per-chunk metadata from ``sliding_chunks``.
        - forensic_stats: dict of seven [0, 1] scores ({} if nothing ran).
        - router_feat: float32 array of length 75.
        - verdict_feat: float32 array of length 28.

        If ``sliding_chunks`` yields no chunks, returns empty lists, an empty
        dict and zero feature vectors.
    """
    if not _models_ready():
        load_models()

    chunk_data = sliding_chunks(wav_mono_tensor, CHUNK_SAMPLES)
    if not chunk_data:
        return (
            [], torch.zeros_like(wav_mono_tensor), [], {},
            np.zeros(_ROUTER_DIM, dtype=np.float32),
            np.zeros(_VERDICT_DIM, dtype=np.float32),
        )

    chunks = [chunk for chunk, _ in chunk_data]
    metadata_list = [meta for _, meta in chunk_data]

    probs: list = []
    all_features = []
    all_mag_list, all_res_list, all_H_list, all_P_list = [], [], [], []
    all_mask_list, all_mel_res_list = [], []

    for start in range(0, len(chunks), BATCH_SIZE):
        batch = torch.stack(chunks[start:start + BATCH_SIZE])  # (B, CHUNK_SAMPLES)

        # STFT (torch, CPU)
        stft = torch.stft(
            batch, N_FFT, HOP_LENGTH, window=_stft_window, return_complex=True)
        stft_mag = stft.abs().unsqueeze(1)  # (B, 1, F, T)

        # UNet mask via ONNX
        mask_np = _unet_sess.run(
            ["mask"],
            {"stft_mag": stft_mag.numpy().astype(np.float32)},
        )[0]
        mask = torch.from_numpy(mask_np)
        res_mag = mask * stft_mag

        # HPSS: CPU median filter (unfold + median), which keeps the training
        # distribution. librosa.decompose.hpss gives different results and makes
        # the v9.4 CNN misclassify (see the warning in CLAUDE.md).
        H_mag, P_mag = hpss_gpu_pure(res_mag)

        # Mel projections for residual / harmonic / percussive
        mel_res = _mel(res_mag)
        mel_H = _mel(H_mag)
        mel_P = _mel(P_mag)
        features_7ch = compute_forensic_features_7ch(mel_res, mel_H, mel_P)
        all_features.append(features_7ch)

        # CNN logit via ONNX, then sigmoid. Flatten so each chunk contributes
        # one float even if the model emits (B, 1) instead of (B,).
        logits = _cnn_sess.run(
            ["logit"],
            {"features_7ch": features_7ch.numpy().astype(np.float32)},
        )[0]
        probs.extend(_sigmoid(logits).reshape(-1).tolist())

        all_mag_list.append(stft_mag)
        all_res_list.append(res_mag)
        all_H_list.append(H_mag)
        all_P_list.append(P_mag)
        all_mask_list.append(mask)
        all_mel_res_list.append(mel_res)

    forensic_stats = _summarize_forensic_features(all_features)

    probs_arr = np.array(probs, dtype=np.float32)
    if all_mag_list:
        router_feat, verdict_feat = _extract_router_verdict_features(
            torch.cat(all_mag_list, dim=0),
            torch.cat(all_res_list, dim=0),
            torch.cat(all_H_list, dim=0),
            torch.cat(all_P_list, dim=0),
            torch.cat(all_mask_list, dim=0),
            torch.cat(all_mel_res_list, dim=0),
            probs_arr,
        )
    else:
        router_feat = np.zeros(_ROUTER_DIM, dtype=np.float32)
        verdict_feat = np.zeros(_VERDICT_DIM, dtype=np.float32)

    residual_placeholder = torch.zeros_like(wav_mono_tensor)
    return probs, residual_placeholder, metadata_list, forensic_stats, router_feat, verdict_feat