# Created: 2026-03-03
# Purpose: ArtifactNet 7ch forensic CNN architecture (PyTorch)
# Dependencies: torch, numpy
"""ArtifactNet model architecture — ArtifactUNet + 7ch Forensic CNN.

v9.0: PyTorch 7ch pipeline (replaces ONNX v8.0).
GPU required for HPSS median filtering.
"""

import numpy as np
import torch
import torch.nn as nn
import torch.nn.functional as F

SR = 44100
N_FFT = 2048
HOP_LENGTH = 512
N_MELS = 128
FREQ_BINS = N_FFT // 2 + 1  # 1025


# ============================================================
# GatedResidualBlock
# ============================================================
class GatedResidualBlock(nn.Module):
    """Residual block with a gated (tanh * sigmoid) dilated-conv bottleneck.

    The input is projected down to ``channels // 2`` channels, passed through
    a dilated 3x3 convolution that emits twice that many channels, split in
    two halves that gate each other, and projected back to ``channels``
    before being added to the input.

    Args:
        channels: number of input (and output) channels.
        dilation: dilation of the 3x3 convolution; the padding is set to the
            same value so the spatial size is preserved.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()
        mid = channels // 2
        self.proj_in = nn.Conv2d(channels, mid, 1)
        self.conv = nn.Conv2d(
            mid, mid * 2, 3, dilation=dilation, padding=dilation)
        self.bn = nn.BatchNorm2d(mid * 2)
        self.proj_out = nn.Conv2d(mid, channels, 1)

    def forward(self, x):
        """Return ``x + proj_out(tanh(a) * sigmoid(b))`` (same shape as x)."""
        h = F.relu(self.proj_in(x))
        h = self.bn(self.conv(h))
        # First half of the channels is the signal, second half the gate.
        a, b = h.chunk(2, dim=1)
        return x + self.proj_out(torch.tanh(a) * torch.sigmoid(b))


# ============================================================
# ConvBlock
# ============================================================
class ConvBlock(nn.Module):
    """Two stacked (3x3 Conv -> BatchNorm -> ReLU) stages.

    Both convolutions use ``padding=1``, so the spatial size is preserved.

    Args:
        in_ch: input channels of the first convolution.
        out_ch: output channels of both convolutions.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        self.block = nn.Sequential(
            nn.Conv2d(in_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
            nn.Conv2d(out_ch, out_ch, 3, padding=1),
            nn.BatchNorm2d(out_ch),
            nn.ReLU(inplace=True),
        )

    def forward(self, x):
        return self.block(x)


# ============================================================
# ArtifactUNet
# ============================================================
class ArtifactUNet(nn.Module):
    """STFT magnitude masking U-Net.

    Four encoder stages (each followed by 2x2 max-pooling), a bottleneck of
    three gated residual blocks with dilations 1/2/4, and four decoder stages
    with skip connections.  The head is a 1x1 convolution followed by
    ``sigmoid(...) * mask_max``, so the mask lies in ``[0, mask_max]``
    (``[0, 0.5]`` with the default).

    Args:
        base_channels: channel count of the first encoder stage; deeper
            stages use 2x, 4x and 8x this value.
        mask_max: upper bound of the predicted mask.
    """

    def __init__(self, base_channels=32, mask_max=0.5):
        super().__init__()
        c = base_channels
        self.mask_max = mask_max

        # Encoder
        self.enc1 = ConvBlock(1, c)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = ConvBlock(c, c * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = ConvBlock(c * 2, c * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = ConvBlock(c * 4, c * 8)
        self.pool4 = nn.MaxPool2d(2, 2)

        self.bottleneck = nn.Sequential(
            GatedResidualBlock(c * 8, dilation=1),
            GatedResidualBlock(c * 8, dilation=2),
            GatedResidualBlock(c * 8, dilation=4),
        )

        # Decoder: each ConvBlock sees the upsampled map concatenated with
        # the matching encoder output, hence the doubled input channels.
        self.up4 = nn.ConvTranspose2d(c * 8, c * 8, 2, stride=2)
        self.dec4 = ConvBlock(c * 16, c * 4)
        self.up3 = nn.ConvTranspose2d(c * 4, c * 4, 2, stride=2)
        self.dec3 = ConvBlock(c * 8, c * 2)
        self.up2 = nn.ConvTranspose2d(c * 2, c * 2, 2, stride=2)
        self.dec2 = ConvBlock(c * 4, c)
        self.up1 = nn.ConvTranspose2d(c, c, 2, stride=2)
        self.dec1 = ConvBlock(c * 2, c)

        self.mask_head = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        """Predict a mask for a single-channel 4-D input.

        Args:
            x: tensor of shape (B, 1, F, T) (``enc1`` expects one channel).

        Returns:
            Mask of shape (B, 1, F, T) with values in ``[0, mask_max]``.
        """
        orig_f, orig_t = x.shape[2], x.shape[3]
        # Four 2x2 poolings need both axes divisible by 16; zero-pad at the
        # bottom/right and crop the mask back to the original size at the end.
        pad_f = (16 - orig_f % 16) % 16
        pad_t = (16 - orig_t % 16) % 16
        if pad_f > 0 or pad_t > 0:
            x = F.pad(x, (0, pad_t, 0, pad_f))

        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))

        b = self.bottleneck(self.pool4(e4))

        d4 = self.dec4(self._skip_cat(self.up4(b), e4))
        d3 = self.dec3(self._skip_cat(self.up3(d4), e3))
        d2 = self.dec2(self._skip_cat(self.up2(d3), e2))
        d1 = self.dec1(self._skip_cat(self.up1(d2), e1))

        mask = torch.sigmoid(self.mask_head(d1)) * self.mask_max
        return mask[:, :, :orig_f, :orig_t]

    @staticmethod
    def _skip_cat(up, skip):
        """Match ``up`` to the spatial size of ``skip`` and concatenate them.

        Each spatial axis is handled independently: an axis of ``up`` that is
        too long is cropped, one that is too short is zero-padded at the end.
        (Previously a mismatch with one axis too short and the other too long
        was only padded, which made the concatenation fail.)

        Args:
            up: upsampled decoder tensor (B, C1, H1, W1).
            skip: encoder tensor (B, C2, H2, W2).

        Returns:
            Tensor (B, C1 + C2, H2, W2) with ``up`` channels first.
        """
        target_f, target_t = skip.shape[2], skip.shape[3]
        # Crop first (no-op on axes that are already short enough) ...
        up = up[:, :, :target_f, :target_t]
        # ... then pad whatever is still missing.
        pad_f = target_f - up.shape[2]
        pad_t = target_t - up.shape[3]
        if pad_f > 0 or pad_t > 0:
            up = F.pad(up, (0, pad_t, 0, pad_f))
        return torch.cat([up, skip], dim=1)


# ============================================================
# ResidualCNNNch (7-channel forensic CNN)
# ============================================================
class ResidualCNNNch(nn.Module):
    """N-channel forensic CNN. Conv-BN-ReLU-Pool structure.

    Three Conv-BN-ReLU stages (32/64/128 channels); the first two are followed
    by 2x2 max-pooling and the last by adaptive average pooling to 4x4.  The
    classifier is Dropout -> Linear(2048, 256) -> ReLU -> Dropout -> Linear.

    Args:
        in_channels: number of input feature channels (default 7).
    """

    def __init__(self, in_channels=7):
        super().__init__()
        self.in_channels = in_channels
        self.features = nn.Sequential(
            nn.Conv2d(in_channels, 32, 3, padding=1),
            nn.BatchNorm2d(32),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(32, 64, 3, padding=1),
            nn.BatchNorm2d(64),
            nn.ReLU(inplace=True),
            nn.MaxPool2d(2, 2),
            nn.Conv2d(64, 128, 3, padding=1),
            nn.BatchNorm2d(128),
            nn.ReLU(inplace=True),
            nn.AdaptiveAvgPool2d((4, 4)),
        )
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Map (B, in_channels, H, W) features to (B,) logits."""
        x = self.features(x)
        # flatten() copies when needed, so non-contiguous feature maps
        # (e.g. channels_last inputs) work; view() would raise on them.
        x = torch.flatten(x, 1)
        x = self.classifier(x)
        return x.squeeze(-1)


class ResidualCNN7ch(nn.Module):
    """7-channel CNN for v9.x SOTA pipeline.

    4-layer Conv + GlobalAvgPool + FC.  Deeper than ResidualCNNNch (3-conv).
    Weights: models/cnn_v94_best.pt (v9.4 SOTA, balanced dataset)
    """

    def __init__(self):
        super().__init__()
        self.conv1 = nn.Conv2d(7, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(256, 128)
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """x: (B, 7, N_MELS, T) → (B,) logits"""
        x = self.pool1(F.relu(self.bn1(self.conv1(x))))
        x = self.pool2(F.relu(self.bn2(self.conv2(x))))
        x = self.pool3(F.relu(self.bn3(self.conv3(x))))
        x = self.pool4(F.relu(self.bn4(self.conv4(x))))
        x = self.global_pool(x).flatten(1)
        # self.dropout was constructed but never applied.  It is an identity
        # in eval mode, so inference with existing weights is unchanged; it
        # only regularises the hidden layer in train mode.
        x = self.dropout(F.relu(self.fc1(x)))
        return self.fc2(x).view(-1)
============================================================ # DifferentiableMel # ============================================================ class DifferentiableMel(nn.Module): """STFT magnitude -> log-mel dB (normalized).""" def __init__(self, sr=44100, n_fft=2048, n_mels=128, top_db=80.0): super().__init__() n_freqs = n_fft // 2 + 1 fb = self._create_mel_fb(n_freqs, n_mels, 0.0, sr / 2, sr) self.register_buffer('fb', fb) self.top_db = top_db @staticmethod def _create_mel_fb(n_freqs, n_mels, f_min, f_max, sr): def hz_to_mel(f): return 2595.0 * np.log10(1.0 + f / 700.0) def mel_to_hz(m): return 700.0 * (10.0 ** (m / 2595.0) - 1.0) mel_min = hz_to_mel(f_min) mel_max = hz_to_mel(f_max) mel_pts = np.linspace(mel_min, mel_max, n_mels + 2) hz_pts = mel_to_hz(mel_pts) freqs = np.linspace(0, sr / 2, n_freqs) fb = np.zeros((n_freqs, n_mels), dtype=np.float32) for i in range(n_mels): lo, mid, hi = hz_pts[i], hz_pts[i + 1], hz_pts[i + 2] for j in range(n_freqs): if lo <= freqs[j] <= mid and (mid - lo) > 0: fb[j, i] = (freqs[j] - lo) / (mid - lo) elif mid < freqs[j] <= hi and (hi - mid) > 0: fb[j, i] = (hi - freqs[j]) / (hi - mid) return torch.from_numpy(fb) def forward(self, stft_mag): """(B, 1, F, T) -> (B, 1, N_MELS, T) log-mel normalized.""" x = stft_mag.squeeze(1) power = x ** 2 mel = torch.einsum('fm,bft->bmt', self.fb, power) mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10)) max_val = mel_db.amax(dim=(-2, -1), keepdim=True) mel_db = torch.clamp(mel_db, min=max_val - self.top_db) mean = mel_db.mean(dim=(-2, -1), keepdim=True) std = mel_db.std(dim=(-2, -1), keepdim=True) mel_norm = (mel_db - mean) / (std + 1e-9) return mel_norm.unsqueeze(1) # ============================================================ # CPU HPSS (librosa) # ============================================================ def hpss_cpu(mag): """HPSS via librosa on CPU. mag: (B, 1, F, T) tensor -> H_mag, P_mag tensors. 각 배치를 numpy로 변환 → librosa.decompose.hpss → 다시 tensor. 데모용 CPU 파이프라인. 
학습용 GPU HPSS는 train_nch_cnn_020303.py 참조. """ import librosa device = mag.device B = mag.shape[0] mag_np = mag.squeeze(1).cpu().numpy() # (B, F, T) H_list, P_list = [], [] for i in range(B): H, P = librosa.decompose.hpss(mag_np[i], kernel_size=31) H_list.append(H) P_list.append(P) H_mag = torch.from_numpy(np.stack(H_list)).unsqueeze(1).to(device) # (B, 1, F, T) P_mag = torch.from_numpy(np.stack(P_list)).unsqueeze(1).to(device) return H_mag, P_mag # ============================================================ # GPU/MPS HPSS (순수 PyTorch — unfold + median, Triton 불필요) # ============================================================ def _gpu_median_filter_2d(x, kernel_size, dim): """GPU median filter along one axis using unfold + median. CUDA에서 빠름. MPS에서는 median이 극도로 느리므로 _avg_filter_2d 사용 권장. Args: x: (B, F, T) tensor on GPU kernel_size: odd integer dim: 1=freq축 (P 추출), 2=time축 (H 추출) """ pad = kernel_size // 2 if dim == 2: x_pad = F.pad(x, (pad, pad), mode='reflect') x_unfold = x_pad.unfold(2, kernel_size, 1) else: x_pad = F.pad(x, (0, 0, pad, pad), mode='reflect') x_unfold = x_pad.unfold(1, kernel_size, 1) return x_unfold.median(dim=-1).values def _avg_filter_2d(x, kernel_size, dim): """avg_pool 기반 smoothing filter — MPS 최적화 (median 대비 400x 빠름). median과 동일하지 않지만, HPSS Wiener masking에서 충분한 근사. H/P 비율 계산에서 절대값보다 상대적 크기가 중요하므로 성능 차이 미미. Args: x: (B, F, T) tensor kernel_size: odd integer dim: 1=freq축, 2=time축 """ pad = kernel_size // 2 B, F_dim, T = x.shape if dim == 2: # time축 x_flat = x.reshape(B * F_dim, 1, T) out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1, padding=pad) return out.reshape(B, F_dim, T) else: # freq축 x_t = x.transpose(1, 2) # (B, T, F) x_flat = x_t.reshape(B * T, 1, F_dim) out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1, padding=pad) return out.reshape(B, T, F_dim).transpose(1, 2) def hpss_gpu_pure(mag, h_kernel=31, p_kernel=31): """순수 PyTorch HPSS — CUDA/MPS 모두 호환. CUDA: unfold + median (정확), MPS: avg_pool 근사 (400x 빠름). 
Args: mag: (B, 1, F, T) STFT magnitude on any device Returns: H_mag, P_mag: (B, 1, F, T) """ mag_sq = mag.squeeze(1) # (B, F, T) # 모든 CNN이 median filter HPSS로 학습됨 → avg_pool 근사 사용 금지 # MPS에서 unfold().median()이 극도로 느림 (13초/곡) → CPU에서 수행 후 복귀 if mag_sq.device.type == 'mps': orig_device = mag_sq.device mag_cpu = mag_sq.cpu() H_filter = _gpu_median_filter_2d(mag_cpu, h_kernel, dim=2).to(orig_device) P_filter = _gpu_median_filter_2d(mag_cpu, p_kernel, dim=1).to(orig_device) else: H_filter = _gpu_median_filter_2d(mag_sq, h_kernel, dim=2) P_filter = _gpu_median_filter_2d(mag_sq, p_kernel, dim=1) H2 = H_filter ** 2 P2 = P_filter ** 2 denom = H2 + P2 + 1e-10 H_mask = H2 / denom P_mask = P2 / denom H_mag = (mag_sq * H_mask).unsqueeze(1) P_mag = (mag_sq * P_mask).unsqueeze(1) return H_mag, P_mag # ============================================================ # 7ch Forensic Feature Computation # ============================================================ def compute_forensic_features_7ch(mel_res, mel_H, mel_P): """Compute 7-channel forensic features from HPSS mel spectrograms. Channels: ch1: mel_residual - UNet residual mel spectrogram ch2: mel_harmonic - HPSS harmonic mel ch3: mel_percussive - HPSS percussive mel ch4: delta - temporal 1st derivative ch5: delta2 - temporal 2nd derivative ch6: hp_ratio - log(H/P) ratio ch7: spectral_flux - |delta| (absolute spectral change) Args: mel_res: (B, 1, N_MELS, T) mel_H: (B, 1, N_MELS, T) mel_P: (B, 1, N_MELS, T) Returns: (B, 7, N_MELS, T) concatenated features """ delta = torch.diff(mel_res, n=1, dim=-1) delta = F.pad(delta, (1, 0)) delta2 = torch.diff(delta, n=1, dim=-1) delta2 = F.pad(delta2, (1, 0)) hp_ratio = mel_H - mel_P spectral_flux = torch.abs(delta) return torch.cat([mel_res, mel_H, mel_P, delta, delta2, hp_ratio, spectral_flux], dim=1)