artifactnet / inference /e2e_model.py
intrect's picture
feat(space): CPU ONNX runtime build (v9.4, full-song sliding aggregation)
0020ddc
raw
history blame
10.7 kB
# Purpose: ArtifactNet 7ch inference pipeline — HF Spaces (CPU, ONNX Runtime)
# Dependencies: onnxruntime, torch (STFT/HPSS/Mel only), numpy, huggingface_hub, scipy
"""ArtifactNet v9.4 inference — onnxruntime CPU.
UNet + CNN 은 .onnx (public-safe) 로 실행, HPSS + Mel + 7ch feature 는
pytorch CPU 로 처리 (가중치 없는 고정 연산이라 노출 위험 없음).
"""
import os
from pathlib import Path
import numpy as np
import onnxruntime as ort
import torch
from huggingface_hub import hf_hub_download
from scipy import stats as sp_stats
from config import (
HF_MODEL_REPO, UNET_ONNX_FILENAME, CNN_ONNX_FILENAME,
SR, N_FFT, HOP_LENGTH, CHUNK_SAMPLES, BATCH_SIZE,
)
from .audio_utils import sliding_chunks
from .model import (
DifferentiableMel, hpss_gpu_pure, compute_forensic_features_7ch,
)
# Number of mel bins: passed to DifferentiableMel and used to size the
# 32-bucket compressed mel profile in the router features.
N_MELS = 128
# Frequency bands as (name, low_hz, high_hz). The feature extractor maps each
# band to STFT bin indices: the first bin at or above low_hz is the start and
# the first bin at or above high_hz is the (exclusive) end of the slice.
FREQ_BANDS = [
("sub", 0, 250),
("low", 250, 2000),
("mid", 2000, 6000),
("hi_mid", 6000, 10000),
("hi", 10000, 16000),
# The verdict features read this last entry by position (index 5).
("air", 16000, 22050),
]
# ============================================================
# Lazy singletons
# ============================================================
# Module-level state filled in by load_models(); everything is None until then.
# ONNX Runtime session for the UNet (produces the "mask" output).
_unet_sess: ort.InferenceSession | None = None
# ONNX Runtime session for the CNN (produces the "logit" output).
_cnn_sess: ort.InferenceSession | None = None
# Mel transform applied to the residual / harmonic / percussive magnitudes.
_mel: DifferentiableMel | None = None
# Hann window of length N_FFT handed to torch.stft.
_stft_window: torch.Tensor | None = None
def _ort_threads() -> int:
    """Return the ONNX Runtime intra-op thread count.

    The value is read from the ``ORT_THREADS`` environment variable. When the
    variable is unset, or its value cannot be parsed as an integer, 2 is
    returned (the original author notes that HF Spaces "CPU basic" has
    2 vCPUs).
    """
    raw_value = os.environ.get("ORT_THREADS", "2")
    try:
        thread_count = int(raw_value)
    except ValueError:
        # Unparseable override: fall back to the default.
        thread_count = 2
    return thread_count
def _resolve_onnx(filename: str, env_var: str) -> str:
    """Return the path of an ONNX model file.

    If the environment variable ``env_var`` (ARTIFACTNET_UNET_ONNX /
    ARTIFACTNET_CNN_ONNX) names an existing file, that local path is returned
    unchanged. Otherwise ``filename`` is fetched from ``HF_MODEL_REPO`` with
    ``hf_hub_download`` and the path it returns is used.
    """
    override = os.environ.get(env_var)
    if not override or not Path(override).is_file():
        # No usable local override: fall back to the Hugging Face Hub.
        return hf_hub_download(HF_MODEL_REPO, filename)
    return override
def load_models():
    """Initialise the ONNX sessions, the Mel transform and the STFT window.

    Intended to run once after import; later calls return immediately once
    the UNet session exists. Model files are located with ``_resolve_onnx``
    (local override via ARTIFACTNET_UNET_ONNX / ARTIFACTNET_CNN_ONNX, else
    the Hugging Face Hub).

    All objects are built in local variables first and only published to the
    module globals after every step succeeded. This way a failure part-way
    through (e.g. the CNN file cannot be loaded) never leaves the module in a
    half-initialised state that the early-return guard would mistake for
    "ready"; the next call simply retries.
    """
    global _unet_sess, _cnn_sess, _mel, _stft_window
    if _unet_sess is not None:
        return

    unet_path = _resolve_onnx(UNET_ONNX_FILENAME, "ARTIFACTNET_UNET_ONNX")
    cnn_path = _resolve_onnx(CNN_ONNX_FILENAME, "ARTIFACTNET_CNN_ONNX")

    # Read the thread setting once so the session and the log line agree.
    threads = _ort_threads()
    opts = ort.SessionOptions()
    opts.intra_op_num_threads = threads
    opts.inter_op_num_threads = 1
    opts.graph_optimization_level = ort.GraphOptimizationLevel.ORT_ENABLE_ALL

    unet_sess = ort.InferenceSession(unet_path, sess_options=opts,
                                     providers=["CPUExecutionProvider"])
    cnn_sess = ort.InferenceSession(cnn_path, sess_options=opts,
                                    providers=["CPUExecutionProvider"])
    mel = DifferentiableMel(sr=SR, n_fft=N_FFT, n_mels=N_MELS)
    mel.eval()
    stft_window = torch.hann_window(N_FFT)

    # Publish. _unet_sess goes last because it doubles as the readiness flag
    # checked here and in run_e2e_inference().
    _cnn_sess = cnn_sess
    _mel = mel
    _stft_window = stft_window
    _unet_sess = unet_sess
    print(f"[hf-spaces] ONNX sessions ready (intra_threads={threads})", flush=True)
# ============================================================
# Feature extraction helpers (75-dim Router + 28-dim Verdict)
# ============================================================
def _extract_router_verdict_features(
    all_mag, all_res, all_H, all_P, all_mask, all_mel_res, probs,
):
    """Build the router and verdict feature vectors for a whole track.

    Per the original author's note this follows the same logic as
    ``extract_features()`` in infer.py, running on CPU.

    Args:
        all_mag: STFT magnitude tensor. It is indexed as a 4-D tensor whose
            axis 2 is frequency (bins spread linearly over 0..SR/2).
        all_res: Residual magnitude tensor, indexed like ``all_mag``.
        all_H: Harmonic magnitude tensor, indexed like ``all_mag``.
        all_P: Percussive magnitude tensor, indexed like ``all_mag``.
        all_mask: Mask tensor; only global statistics of its values are used.
        all_mel_res: Mel tensor of the residual; averaged over axes 0 and 3
            to obtain one value per mel bin.
        probs: Per-chunk probabilities. Any array-like is accepted and is
            flattened to 1-D, so an ``(N, 1)`` array behaves like ``(N,)``
            instead of failing in the first-difference statistics.

    Returns:
        ``(router_feat, verdict_feat)`` as float32 NumPy arrays with NaN/inf
        replaced by ``np.nan_to_num``. With the six entries of FREQ_BANDS the
        router vector has 18 + 32 + 1 + 12 + 7 + 5 = 75 values and the
        verdict vector has 20 + 8 = 28 values.
    """
    n_freq = all_mag.shape[2]
    freq_hz = torch.linspace(0, SR / 2, n_freq)
    # Small epsilons keep the energy ratios finite for silent input.
    orig_total = all_mag.pow(2).mean().item() + 1e-8
    res_total = all_res.pow(2).mean().item() + 1e-8

    # Translate each (low_hz, high_hz) band into [start, end) STFT bin indices.
    band_idx = []
    for _, flo, fhi in FREQ_BANDS:
        lo = (freq_hz >= flo).nonzero(as_tuple=True)[0]
        hi = (freq_hz >= fhi).nonzero(as_tuple=True)[0]
        band_idx.append((
            lo[0].item() if len(lo) else 0,
            hi[0].item() if len(hi) else n_freq,
        ))

    # --- Router features -------------------------------------------------
    rf = []
    # Per band: original energy share, residual energy share, and their ratio.
    for i0, i1 in band_idx:
        oe = all_mag[:, :, i0:i1, :].pow(2).mean().item() / orig_total
        re = all_res[:, :, i0:i1, :].pow(2).mean().item() / res_total
        rf.extend([oe, re, re / (oe + 1e-8)])

    # Mel profile of the residual, pooled into 32 buckets, mean-removed and
    # scaled by its largest absolute value.
    mel_profile = all_mel_res.mean(dim=[0, 3]).squeeze().cpu().numpy()
    step = N_MELS // 32
    compressed = mel_profile[:32 * step].reshape(32, step).mean(axis=1)
    compressed = compressed - compressed.mean()
    norm = np.abs(compressed).max() + 1e-8
    rf.extend((compressed / norm).tolist())

    # Harmonic vs. percussive energy: overall ratio, then per-band shares.
    H_total = all_H.pow(2).mean().item() + 1e-8
    P_total = all_P.pow(2).mean().item() + 1e-8
    hp_ratio = H_total / (H_total + P_total)
    rf.append(hp_ratio)
    for i0, i1 in band_idx:
        rf.extend([
            all_H[:, :, i0:i1, :].pow(2).mean().item() / H_total,
            all_P[:, :, i0:i1, :].pow(2).mean().item() / P_total,
        ])

    # Distribution statistics of the mask values.
    mask_np = all_mask.cpu().numpy().flatten()
    rf.extend([
        float(mask_np.mean()), float(mask_np.std()),
        float(np.percentile(mask_np, 10)), float(np.percentile(mask_np, 25)),
        float(np.percentile(mask_np, 75)), float(np.percentile(mask_np, 90)),
        float(np.median(mask_np)),
    ])

    # Flatten once so every statistic below sees one value per chunk, no
    # matter whether the caller passed (N,), (N, 1) or a plain list.
    probs_flat = np.asarray(probs).ravel()
    rf.extend([
        float(probs_flat.mean()), float(probs_flat.std()),
        float(np.median(probs_flat)),
        float(np.percentile(probs_flat, 10)),
        float(np.percentile(probs_flat, 90)),
    ])
    router_feat = np.nan_to_num(np.array(rf, dtype=np.float32))

    # --- Verdict features ------------------------------------------------
    # 20 statistics of the per-chunk probabilities, computed in float64.
    arr = probs_flat.astype(np.float64)
    n = len(arr)
    cnn_20 = np.array([
        n, arr.mean(), arr.std(), np.median(arr),
        arr.min(), arr.max(), arr.max() - arr.min(),
        np.percentile(arr, 10), np.percentile(arr, 25),
        np.percentile(arr, 75), np.percentile(arr, 90),
        # Fraction of chunks at or above each threshold.
        (arr >= 0.3).mean(), (arr >= 0.5).mean(),
        (arr >= 0.7).mean(), (arr >= 0.8).mean(), (arr >= 0.9).mean(),
        # Shape statistics are only computed with at least three chunks.
        float(sp_stats.skew(arr)) if n >= 3 else 0.0,
        float(sp_stats.kurtosis(arr)) if n >= 3 else 0.0,
        # Chunk-to-chunk differences need at least two chunks.
        float(np.diff(arr).std()) if n >= 2 else 0.0,
        float(np.abs(np.diff(arr)).max()) if n >= 2 else 0.0,
    ], dtype=np.float32)

    # 8 residual statistics: energy above 8 kHz, energy in the sixth band
    # ("air" in FREQ_BANDS), the two ends of the mel profile, and two values
    # shared with the router vector.
    hf8k_i = (freq_hz >= 8000).nonzero(as_tuple=True)[0]
    hf8k_i = hf8k_i[0].item() if len(hf8k_i) else n_freq
    ai0, ai1 = band_idx[5]
    res_8 = np.array([
        all_res[:, :, hf8k_i:, :].pow(2).mean().item() / res_total,
        all_res[:, :, ai0:ai1, :].pow(2).mean().item() / res_total,
        all_H[:, :, ai0:ai1, :].pow(2).mean().item() / H_total,
        all_P[:, :, ai0:ai1, :].pow(2).mean().item() / P_total,
        float(mel_profile[-1]),
        float(mel_profile[0]),
        float(mask_np.mean()),
        float(hp_ratio),
    ], dtype=np.float32)

    verdict_feat = np.nan_to_num(np.concatenate([cnn_20, res_8]))
    return router_feat, verdict_feat
# ============================================================
# Inference
# ============================================================
@torch.no_grad()
def run_e2e_inference(wav_mono_tensor: torch.Tensor):
    """Run the full detection pipeline on a mono waveform (ONNX Runtime CPU + torch).

    The waveform is cut into chunks by ``sliding_chunks`` and processed in
    batches of ``BATCH_SIZE``: STFT magnitude -> UNet mask (ONNX) -> masked
    residual -> HPSS -> Mel -> 7-channel features -> CNN logit (ONNX) ->
    sigmoid. Models are loaded on first use.

    Args:
        wav_mono_tensor: Mono waveform tensor, passed unchanged to
            ``sliding_chunks``.

    Returns:
        A 6-tuple ``(probs, residual_placeholder, metadata_list,
        forensic_stats, router_feat, verdict_feat)``:

        * ``probs`` - flat list with one probability per chunk.
        * ``residual_placeholder`` - zeros shaped like the input waveform.
        * ``metadata_list`` - the per-chunk metadata from ``sliding_chunks``.
        * ``forensic_stats`` - dict of seven values in [0, 1], one per
          feature channel.
        * ``router_feat`` / ``verdict_feat`` - float32 vectors from
          ``_extract_router_verdict_features`` (zeros of length 75 / 28 when
          there is nothing to analyse).
    """
    if _unet_sess is None:
        load_models()

    chunk_data = sliding_chunks(wav_mono_tensor, CHUNK_SAMPLES)
    if not chunk_data:
        # No chunks: return empty results with the usual shapes.
        return [], torch.zeros_like(wav_mono_tensor), [], {}, \
            np.zeros(75, dtype=np.float32), np.zeros(28, dtype=np.float32)

    chunks = [chunk for chunk, _ in chunk_data]
    metadata_list = [meta for _, meta in chunk_data]

    probs = []
    all_features = []
    all_mag_list, all_res_list, all_H_list, all_P_list = [], [], [], []
    all_mask_list, all_mel_res_list = [], []

    for i in range(0, len(chunks), BATCH_SIZE):
        batch = torch.stack(chunks[i:i + BATCH_SIZE])  # (B, CHUNK_SAMPLES)

        # STFT (torch, CPU)
        stft = torch.stft(
            batch, N_FFT, HOP_LENGTH,
            window=_stft_window, return_complex=True)
        stft_mag = stft.abs().unsqueeze(1)  # (B, 1, F, T)

        # UNet mask via ONNX
        mask_np = _unet_sess.run(
            ["mask"],
            {"stft_mag": stft_mag.numpy().astype(np.float32)},
        )[0]
        mask = torch.from_numpy(mask_np)
        res_mag = mask * stft_mag

        # HPSS via the CPU median filter (unfold + median) to stay on the
        # training distribution. librosa.decompose.hpss gives different
        # results and makes the v9.4 CNN misclassify (see the warning in
        # CLAUDE.md).
        H_mag, P_mag = hpss_gpu_pure(res_mag)

        # Mel of the residual, harmonic and percussive magnitudes.
        mel_res = _mel(res_mag)
        mel_H = _mel(H_mag)
        mel_P = _mel(P_mag)

        features_7ch = compute_forensic_features_7ch(mel_res, mel_H, mel_P)
        all_features.append(features_7ch)

        # CNN logit via ONNX -> sigmoid
        logits = _cnn_sess.run(
            ["logit"],
            {"features_7ch": features_7ch.numpy().astype(np.float32)},
        )[0]
        # Force one flat value per chunk: a 0-d output would make .tolist()
        # return a bare float (extend() then fails) and a (B, 1) output would
        # produce nested lists.
        logits = np.asarray(logits).reshape(-1)
        # Clipping keeps exp() from overflowing on extreme logits.
        batch_probs = (1.0 / (1.0 + np.exp(-np.clip(logits, -30, 30)))).tolist()
        probs.extend(batch_probs)

        all_mag_list.append(stft_mag)
        all_res_list.append(res_mag)
        all_H_list.append(H_mag)
        all_P_list.append(P_mag)
        all_mask_list.append(mask)
        all_mel_res_list.append(mel_res)

    if all_features:
        # Per chunk and channel: mean over the last two axes. The median over
        # chunks is then min-max scaled within the track.
        all_feat_tensor = torch.cat(all_features, dim=0)
        channel_means = all_feat_tensor.mean(dim=[2, 3])
        feature_medians = channel_means.median(dim=0).values
        feat_min = channel_means.min(dim=0).values
        feat_max = channel_means.max(dim=0).values
        feat_range = feat_max - feat_min + 1e-8
        normalized = ((feature_medians - feat_min) / feat_range).clamp(0, 1)
        forensic_stats = {
            "residual_energy": float(normalized[0]),
            "harmonic_strength": float(normalized[1]),
            "percussive_strength": float(normalized[2]),
            "temporal_delta": float(normalized[3]),
            "temporal_accel": float(normalized[4]),
            "hp_ratio": float(normalized[5]),
            "spectral_flux": float(normalized[6]),
        }
    else:
        forensic_stats = {}

    probs_arr = np.array(probs, dtype=np.float32)
    if all_mag_list:
        # Whole-track tensors for the router / verdict feature vectors.
        all_mag = torch.cat(all_mag_list, dim=0)
        all_res = torch.cat(all_res_list, dim=0)
        all_H = torch.cat(all_H_list, dim=0)
        all_P = torch.cat(all_P_list, dim=0)
        all_mask = torch.cat(all_mask_list, dim=0)
        all_mel_res = torch.cat(all_mel_res_list, dim=0)
        router_feat, verdict_feat = _extract_router_verdict_features(
            all_mag, all_res, all_H, all_P, all_mask, all_mel_res, probs_arr,
        )
    else:
        router_feat = np.zeros(75, dtype=np.float32)
        verdict_feat = np.zeros(28, dtype=np.float32)

    residual_placeholder = torch.zeros_like(wav_mono_tensor)
    return probs, residual_placeholder, metadata_list, forensic_stats, router_feat, verdict_feat