artifactnet / inference /model.py
intrect's picture
feat(space): CPU ONNX runtime build (v9.4, full-song sliding aggregation)
0020ddc
raw
history blame
14.4 kB
# Created: 2026-03-03
# Purpose: ArtifactNet 7ch Forensic CNN architecture (PyTorch)
# Dependencies: torch, numpy
"""ArtifactNet model architecture — ArtifactUNet + 7ch Forensic CNN.
v9.0: PyTorch 7ch pipeline (replaces ONNX v8.0).
GPU recommended for HPSS median filtering (a CPU path via librosa is provided by hpss_cpu).
"""
import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F
# Shared audio / STFT settings.
# NOTE(review): nothing in this part of the file reads these constants
# (DifferentiableMel repeats the same numbers as its own defaults) —
# presumably other modules import them; confirm before changing.
SR = 44100  # presumably the audio sample rate in Hz
N_FFT = 2048  # presumably the STFT/FFT size in samples
HOP_LENGTH = 512  # presumably the STFT hop in samples
N_MELS = 128  # presumably the number of mel bands
FREQ_BINS = N_FFT // 2 + 1 # 1025
# ============================================================
# GatedResidualBlock
# ============================================================
class GatedResidualBlock(nn.Module):
    """Residual bottleneck block with a gated, dilated 3x3 convolution.

    The input is squeezed to ``channels // 2`` feature maps with a 1x1
    convolution, passed through a dilated 3x3 convolution that doubles the
    width, batch-normalised and split into two halves. One half goes through
    ``tanh``, the other through ``sigmoid``; their product is projected back
    to ``channels`` and added to the input.

    Args:
        channels: number of input (and output) feature maps.
        dilation: dilation of the 3x3 convolution; the same value is used as
            padding so the spatial size is preserved.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()
        half = channels // 2
        # Submodule names and creation order are kept stable so existing
        # checkpoints (state_dict keys) continue to load.
        self.proj_in = nn.Conv2d(channels, half, kernel_size=1)
        self.conv = nn.Conv2d(
            half,
            2 * half,
            kernel_size=3,
            dilation=dilation,
            padding=dilation,
        )
        self.bn = nn.BatchNorm2d(2 * half)
        self.proj_out = nn.Conv2d(half, channels, kernel_size=1)

    def forward(self, x):
        """Apply the gated bottleneck and add the result to ``x``."""
        squeezed = F.relu(self.proj_in(x))
        normed = self.bn(self.conv(squeezed))
        # First half of the channels is the content path, second half the gate.
        content, gate = torch.chunk(normed, 2, dim=1)
        gated = torch.tanh(content) * torch.sigmoid(gate)
        return x + self.proj_out(gated)
# ============================================================
# ConvBlock
# ============================================================
class ConvBlock(nn.Module):
    """Two stacked Conv3x3 -> BatchNorm -> ReLU stages.

    Both convolutions use ``padding=1`` so the spatial size is unchanged; the
    first maps ``in_ch`` to ``out_ch`` channels and the second keeps
    ``out_ch``.

    Args:
        in_ch: number of input feature maps.
        out_ch: number of output feature maps.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        stages = []
        # Layers are appended in the order conv, bn, relu, conv, bn, relu so
        # the nn.Sequential indices (block.0 ... block.5) stay the same.
        for source_ch in (in_ch, out_ch):
            stages.append(nn.Conv2d(source_ch, out_ch, 3, padding=1))
            stages.append(nn.BatchNorm2d(out_ch))
            stages.append(nn.ReLU(inplace=True))
        self.block = nn.Sequential(*stages)

    def forward(self, x):
        """Run ``x`` through both conv stages."""
        return self.block(x)
# ============================================================
# ArtifactUNet
# ============================================================
class ArtifactUNet(nn.Module):
    """U-Net that predicts a bounded mask over a single-channel 2-D input.

    The network has four encoder stages (ConvBlock + 2x2 max-pool), a
    bottleneck of three GatedResidualBlocks with dilations 1, 2 and 4, and
    four decoder stages (transposed conv + skip concatenation + ConvBlock).
    The output is ``sigmoid(...) * mask_max``, i.e. values in
    ``(0, mask_max)``.

    Args:
        base_channels: width of the first encoder stage; deeper stages use
            2x, 4x and 8x this value.
        mask_max: upper bound of the predicted mask.
    """

    def __init__(self, base_channels=32, mask_max=0.5):
        super().__init__()
        c = base_channels
        self.mask_max = mask_max

        # Encoder. Attribute names are part of the checkpoint format.
        self.enc1 = ConvBlock(1, c)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = ConvBlock(c, c * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = ConvBlock(c * 2, c * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = ConvBlock(c * 4, c * 8)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.bottleneck = nn.Sequential(
            GatedResidualBlock(c * 8, dilation=1),
            GatedResidualBlock(c * 8, dilation=2),
            GatedResidualBlock(c * 8, dilation=4),
        )

        # Decoder. Each dec* block sees the upsampled features concatenated
        # with the matching encoder output, hence the doubled input width.
        self.up4 = nn.ConvTranspose2d(c * 8, c * 8, 2, stride=2)
        self.dec4 = ConvBlock(c * 16, c * 4)
        self.up3 = nn.ConvTranspose2d(c * 4, c * 4, 2, stride=2)
        self.dec3 = ConvBlock(c * 8, c * 2)
        self.up2 = nn.ConvTranspose2d(c * 2, c * 2, 2, stride=2)
        self.dec2 = ConvBlock(c * 4, c)
        self.up1 = nn.ConvTranspose2d(c, c, 2, stride=2)
        self.dec1 = ConvBlock(c * 2, c)
        self.mask_head = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        """Predict a mask with the same last two dimensions as ``x``.

        Args:
            x: 4-D tensor whose channel dimension must be 1 (enc1 expects a
                single input channel).

        Returns:
            Tensor of shape ``(B, 1, x.shape[2], x.shape[3])`` with values in
            ``(0, mask_max)``.
        """
        orig_f, orig_t = x.shape[2], x.shape[3]
        # Four 2x poolings need both spatial dims to be multiples of 16;
        # zero-pad at the end of each axis and crop the mask afterwards.
        pad_f = (16 - orig_f % 16) % 16
        pad_t = (16 - orig_t % 16) % 16
        if pad_f > 0 or pad_t > 0:
            x = F.pad(x, (0, pad_t, 0, pad_f))

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        b = self.bottleneck(self.pool4(e4))

        d4 = self.dec4(self._skip_cat(self.up4(b), e4))
        d3 = self.dec3(self._skip_cat(self.up3(d4), e3))
        d2 = self.dec2(self._skip_cat(self.up2(d3), e2))
        d1 = self.dec1(self._skip_cat(self.up1(d2), e1))

        mask = torch.sigmoid(self.mask_head(d1)) * self.mask_max
        return mask[:, :, :orig_f, :orig_t]

    @staticmethod
    def _skip_cat(up, skip):
        """Match ``up`` to the spatial size of ``skip`` and concatenate.

        Each spatial axis is handled independently: an axis where ``up`` is
        smaller is zero-padded at its end, an axis where it is larger is
        cropped. The previous ``if``/``elif`` form skipped the crop whenever
        any axis needed padding, so a mixed mismatch made ``torch.cat`` fail.

        Returns:
            ``torch.cat([up, skip], dim=1)`` after the size adjustment.
        """
        df = skip.shape[2] - up.shape[2]
        dt = skip.shape[3] - up.shape[3]
        if df > 0 or dt > 0:
            up = F.pad(up, (0, max(dt, 0), 0, max(df, 0)))
        if df < 0 or dt < 0:
            up = up[:, :, :skip.shape[2], :skip.shape[3]]
        return torch.cat([up, skip], dim=1)
# ============================================================
# ResidualCNNNch (7-channel forensic CNN)
# ============================================================
class ResidualCNNNch(nn.Module):
    """N-channel forensic CNN built from Conv-BN-ReLU-Pool stages.

    Three 3x3 convolution stages (32, 64, 128 channels) are followed by an
    adaptive average pool to 4x4 and a two-layer classifier with dropout
    that emits one logit per sample.

    Args:
        in_channels: number of input feature channels (default 7).
    """

    def __init__(self, in_channels=7):
        super().__init__()
        self.in_channels = in_channels

        # Layer order defines the nn.Sequential indices used in checkpoints
        # (features.0 ... features.11), so it must not change.
        feature_layers = []
        widths = (in_channels, 32, 64, 128)
        for stage, (src, dst) in enumerate(zip(widths[:-1], widths[1:])):
            feature_layers.append(nn.Conv2d(src, dst, 3, padding=1))
            feature_layers.append(nn.BatchNorm2d(dst))
            feature_layers.append(nn.ReLU(inplace=True))
            if stage < 2:
                feature_layers.append(nn.MaxPool2d(2, 2))
            else:
                # The last stage pools to a fixed 4x4 grid instead.
                feature_layers.append(nn.AdaptiveAvgPool2d((4, 4)))
        self.features = nn.Sequential(*feature_layers)

        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Map ``(B, in_channels, H, W)`` features to ``(B,)`` logits."""
        pooled = self.features(x)
        flat = pooled.view(pooled.size(0), -1)
        logits = self.classifier(flat)
        return logits.squeeze(-1)
class ResidualCNN7ch(nn.Module):
    """7-channel CNN for the v9.x pipeline.

    Four Conv3x3 -> BatchNorm -> ReLU -> MaxPool(2) stages (32, 64, 128, 256
    channels), global average pooling and two fully connected layers that
    produce one logit per sample. The original note describes it as deeper
    than ResidualCNNNch (three conv stages) and names
    ``models/cnn_v94_best.pt`` (v9.4, balanced dataset) as its weight file.

    NOTE(review): ``self.dropout`` is registered but never applied in
    ``forward`` — confirm whether that is intentional.
    """

    def __init__(self):
        super().__init__()
        # Attribute names and creation order are kept so checkpoints load.
        self.conv1 = nn.Conv2d(7, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)

        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)

        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)

        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2)

        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(256, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """x: (B, 7, N_MELS, T) -> (B,) logits."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
        )
        for conv, norm, pool in stages:
            x = pool(F.relu(norm(conv(x))))
        batch = x.size(0)
        embedding = self.global_pool(x).view(batch, -1)
        hidden = F.relu(self.fc1(embedding))
        return self.fc2(hidden).view(-1)
# ============================================================
# DifferentiableMel
# ============================================================
class DifferentiableMel(nn.Module):
    """STFT magnitude -> log-mel dB, normalised per sample.

    A triangular mel filterbank is built once and stored as the buffer
    ``fb`` with shape ``(n_fft // 2 + 1, n_mels)``.

    Args:
        sr: sample rate used to place the filterbank (filters span 0..sr/2).
        n_fft: FFT size; the input must have ``n_fft // 2 + 1`` freq bins.
        n_mels: number of mel bands.
        top_db: dynamic range kept below each sample's maximum, in dB.
    """

    def __init__(self, sr=44100, n_fft=2048, n_mels=128, top_db=80.0):
        super().__init__()
        n_freqs = n_fft // 2 + 1
        fb = self._create_mel_fb(n_freqs, n_mels, 0.0, sr / 2, sr)
        self.register_buffer('fb', fb)
        self.top_db = top_db

    @staticmethod
    def _create_mel_fb(n_freqs, n_mels, f_min, f_max, sr):
        """Build a triangular mel filterbank.

        Band edges are spaced uniformly on the mel scale
        (``2595 * log10(1 + f / 700)``) between ``f_min`` and ``f_max``;
        bin frequencies are spaced linearly from 0 to ``sr / 2``. The
        filters are not area-normalised.

        Returns:
            float32 tensor of shape ``(n_freqs, n_mels)``.
        """
        def hz_to_mel(f):
            return 2595.0 * np.log10(1.0 + f / 700.0)

        def mel_to_hz(m):
            return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

        mel_pts = np.linspace(hz_to_mel(f_min), hz_to_mel(f_max), n_mels + 2)
        hz_pts = mel_to_hz(mel_pts)
        freqs = np.linspace(0, sr / 2, n_freqs)

        # Broadcast (n_freqs, 1) bin frequencies against (1, n_mels) band
        # edges instead of looping over every (bin, band) pair in Python.
        f = freqs[:, None]
        lo = hz_pts[:-2][None, :]
        mid = hz_pts[1:-1][None, :]
        hi = hz_pts[2:][None, :]
        rise_width = mid - lo
        fall_width = hi - mid

        # Degenerate (zero-width) slopes are masked out below, so the
        # division warnings they would trigger are irrelevant.
        with np.errstate(divide='ignore', invalid='ignore'):
            rising = (f - lo) / rise_width
            falling = (hi - f) / fall_width

        on_rise = (lo <= f) & (f <= mid) & (rise_width > 0)
        on_fall = (mid < f) & (f <= hi) & (fall_width > 0)
        fb = np.where(on_rise, rising, np.where(on_fall, falling, 0.0))
        return torch.from_numpy(fb.astype(np.float32))

    def forward(self, stft_mag):
        """(B, 1, F, T) -> (B, 1, N_MELS, T) log-mel normalized.

        The magnitude is squared to power, projected onto the mel
        filterbank, converted to dB (floored at 1e-10), clamped to
        ``top_db`` below each sample's maximum, and finally standardised
        (zero mean, unit std) over the mel and time axes of each sample.
        """
        x = stft_mag.squeeze(1)
        power = x ** 2
        mel = torch.einsum('fm,bft->bmt', self.fb, power)
        mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10))
        max_val = mel_db.amax(dim=(-2, -1), keepdim=True)
        mel_db = torch.clamp(mel_db, min=max_val - self.top_db)
        mean = mel_db.mean(dim=(-2, -1), keepdim=True)
        std = mel_db.std(dim=(-2, -1), keepdim=True)
        mel_norm = (mel_db - mean) / (std + 1e-9)
        return mel_norm.unsqueeze(1)
# ============================================================
# CPU HPSS (librosa)
# ============================================================
def hpss_cpu(mag):
    """Harmonic/percussive separation on the CPU via librosa.

    Each batch item is converted to a NumPy array, passed through
    ``librosa.decompose.hpss`` with ``kernel_size=31`` and converted back to
    a tensor on the input's device. The original note describes this as the
    CPU pipeline for demos and points to ``train_nch_cnn_020303.py`` for the
    GPU HPSS used in training.

    The result does not carry gradients: the input is detached before the
    NumPy conversion, so tensors that require grad are accepted as well.

    Args:
        mag: (B, 1, F, T) magnitude tensor.

    Returns:
        Tuple ``(H_mag, P_mag)``, each of shape (B, 1, F, T) on
        ``mag.device``.
    """
    import librosa  # local import: only this CPU path needs librosa

    device = mag.device
    # detach() first: Tensor.numpy() raises on tensors that require grad.
    mag_np = mag.detach().squeeze(1).cpu().numpy()  # (B, F, T)

    separated = [
        librosa.decompose.hpss(item, kernel_size=31) for item in mag_np
    ]
    harmonic = np.stack([h for h, _ in separated])
    percussive = np.stack([p for _, p in separated])

    H_mag = torch.from_numpy(harmonic).unsqueeze(1).to(device)  # (B, 1, F, T)
    P_mag = torch.from_numpy(percussive).unsqueeze(1).to(device)
    return H_mag, P_mag
# ============================================================
# GPU/MPS HPSS (pure PyTorch — unfold + median, no Triton required)
# ============================================================
def _gpu_median_filter_2d(x, kernel_size, dim):
"""GPU median filter along one axis using unfold + median.
CUDA에서 빠름. MPS에서는 median이 극도로 느리므로 _avg_filter_2d 사용 권장.
Args:
x: (B, F, T) tensor on GPU
kernel_size: odd integer
dim: 1=freq축 (P 추출), 2=time축 (H 추출)
"""
pad = kernel_size // 2
if dim == 2:
x_pad = F.pad(x, (pad, pad), mode='reflect')
x_unfold = x_pad.unfold(2, kernel_size, 1)
else:
x_pad = F.pad(x, (0, 0, pad, pad), mode='reflect')
x_unfold = x_pad.unfold(1, kernel_size, 1)
return x_unfold.median(dim=-1).values
def _avg_filter_2d(x, kernel_size, dim):
"""avg_pool 기반 smoothing filter — MPS 최적화 (median 대비 400x 빠름).
median과 동일하지 않지만, HPSS Wiener masking에서 충분한 근사.
H/P 비율 계산에서 절대값보다 상대적 크기가 중요하므로 성능 차이 미미.
Args:
x: (B, F, T) tensor
kernel_size: odd integer
dim: 1=freq축, 2=time축
"""
pad = kernel_size // 2
B, F_dim, T = x.shape
if dim == 2: # time축
x_flat = x.reshape(B * F_dim, 1, T)
out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1, padding=pad)
return out.reshape(B, F_dim, T)
else: # freq축
x_t = x.transpose(1, 2) # (B, T, F)
x_flat = x_t.reshape(B * T, 1, F_dim)
out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1, padding=pad)
return out.reshape(B, T, F_dim).transpose(1, 2)
def hpss_gpu_pure(mag, h_kernel=31, p_kernel=31):
    """Pure-PyTorch HPSS using median filters and soft (power-ratio) masks.

    A median filter along time gives the harmonic estimate and one along
    frequency gives the percussive estimate; their squared values form
    masks ``H^2 / (H^2 + P^2)`` and ``P^2 / (H^2 + P^2)`` that are applied
    to the input magnitude.

    The median filter is always used. Per the original notes, every CNN was
    trained on median-filter HPSS, so the avg_pool approximation must not be
    substituted, and ``unfold().median()`` is extremely slow on MPS (about
    13 s per song). MPS inputs are therefore filtered on the CPU and the
    results moved back to the original device.

    Args:
        mag: (B, 1, F, T) STFT magnitude on any device.
        h_kernel: window length of the time-axis (harmonic) median filter.
        p_kernel: window length of the frequency-axis (percussive) filter.

    Returns:
        H_mag, P_mag: each (B, 1, F, T).
    """
    spec = mag.squeeze(1)  # (B, F, T)
    home = spec.device

    # Only MPS tensors take the CPU detour; for every other device the
    # trailing .to(home) is a no-op.
    work = spec.cpu() if home.type == 'mps' else spec
    harm_ref = _gpu_median_filter_2d(work, h_kernel, dim=2).to(home)
    perc_ref = _gpu_median_filter_2d(work, p_kernel, dim=1).to(home)

    harm_pow = harm_ref ** 2
    perc_pow = perc_ref ** 2
    # Small epsilon keeps the ratio finite where both estimates are zero.
    total = harm_pow + perc_pow + 1e-10

    H_mag = (spec * (harm_pow / total)).unsqueeze(1)
    P_mag = (spec * (perc_pow / total)).unsqueeze(1)
    return H_mag, P_mag
# ============================================================
# 7ch Forensic Feature Computation
# ============================================================
def compute_forensic_features_7ch(mel_res, mel_H, mel_P):
    """Stack the 7 forensic feature channels along dim 1.

    Channels, in output order:
        1. mel_res        - residual mel spectrogram (as passed in)
        2. mel_H          - harmonic mel (as passed in)
        3. mel_P          - percussive mel (as passed in)
        4. delta          - first temporal difference of mel_res
        5. delta2         - temporal difference of delta
        6. hp_ratio       - mel_H - mel_P (a log-ratio if the inputs are
                            log-scaled)
        7. spectral_flux  - absolute value of delta

    Differences are taken along the last axis and left-padded with one zero
    frame so every channel keeps the input length.

    Args:
        mel_res: (B, 1, N_MELS, T)
        mel_H: (B, 1, N_MELS, T)
        mel_P: (B, 1, N_MELS, T)

    Returns:
        (B, 7, N_MELS, T) concatenated features.
    """
    def _frame_diff(signal):
        # First difference over time with a zero prepended to keep length T.
        return F.pad(torch.diff(signal, n=1, dim=-1), (1, 0))

    first_order = _frame_diff(mel_res)
    second_order = _frame_diff(first_order)

    channels = [
        mel_res,
        mel_H,
        mel_P,
        first_order,
        second_order,
        mel_H - mel_P,
        first_order.abs(),
    ]
    return torch.cat(channels, dim=1)