Spaces:
Runtime error
Runtime error
| # Created: 2026-03-03 | |
| # Purpose: ArtifactNet 7ch Forensic CNN 아키텍처 (PyTorch) | |
| # Dependencies: torch, numpy | |
"""ArtifactNet model architecture — ArtifactUNet + 7ch Forensic CNN.
v9.0: PyTorch 7ch pipeline (replaces ONNX v8.0).
HPSS median filtering runs in PyTorch on CUDA or CPU (MPS inputs are
filtered on CPU); a librosa-based CPU path (hpss_cpu) is also provided.
"""
| import numpy as np | |
| import torch | |
| import torch.nn as nn | |
| import torch.nn.functional as F | |
# Audio front-end constants.
SR = 44_100  # sample rate (Hz)
N_FFT = 2_048  # FFT size
HOP_LENGTH = 512  # STFT hop length
N_MELS = 128  # number of mel bands
FREQ_BINS = 1 + N_FFT // 2  # one-sided STFT bins -> 1025
| # ============================================================ | |
| # GatedResidualBlock | |
| # ============================================================ | |
class GatedResidualBlock(nn.Module):
    """Residual bottleneck with a gated (tanh * sigmoid) dilated 3x3 conv.

    Args:
        channels: input/output channel count; the bottleneck width is
            ``channels // 2``.
        dilation: dilation of the 3x3 conv. Padding equals the dilation, so
            the spatial size is preserved and the residual add lines up.
    """

    def __init__(self, channels, dilation=1):
        super().__init__()
        hidden = channels // 2
        # Attribute names form the state_dict keys of saved checkpoints.
        self.proj_in = nn.Conv2d(channels, hidden, 1)
        self.conv = nn.Conv2d(
            hidden, 2 * hidden, 3, dilation=dilation, padding=dilation)
        self.bn = nn.BatchNorm2d(2 * hidden)
        self.proj_out = nn.Conv2d(hidden, channels, 1)

    def forward(self, x):
        """Return ``x + proj_out(tanh(a) * sigmoid(b))``.

        ``a`` and ``b`` are the two channel halves of the normalized
        dilated conv output.
        """
        squeezed = F.relu(self.proj_in(x))
        filt, gate = torch.chunk(self.bn(self.conv(squeezed)), 2, dim=1)
        return x + self.proj_out(torch.tanh(filt) * torch.sigmoid(gate))
| # ============================================================ | |
| # ConvBlock | |
| # ============================================================ | |
class ConvBlock(nn.Module):
    """Two stacked (3x3 conv -> BatchNorm -> ReLU) stages; spatial size preserved.

    Args:
        in_ch: input channels of the first conv.
        out_ch: output channels of both convs.
    """

    def __init__(self, in_ch, out_ch):
        super().__init__()
        layers = []
        for src_ch in (in_ch, out_ch):
            layers += [
                nn.Conv2d(src_ch, out_ch, 3, padding=1),
                nn.BatchNorm2d(out_ch),
                nn.ReLU(inplace=True),
            ]
        # Indices 0..5 inside ``block`` define the state_dict keys.
        self.block = nn.Sequential(*layers)

    def forward(self, x):
        """Apply both conv stages."""
        return self.block(x)
| # ============================================================ | |
| # ArtifactUNet | |
| # ============================================================ | |
class ArtifactUNet(nn.Module):
    """STFT-magnitude masking U-Net.

    The encoder has four stages, each followed by 2x2 max pooling. The
    bottleneck is three gated residual blocks (dilations 1, 2, 4). The decoder
    has four transposed-conv upsampling stages with skip connections.
    The output mask is ``sigmoid(logits) * mask_max``, so values lie in
    ``[0, mask_max]`` (``[0, 0.5]`` with the default).

    Args:
        base_channels: width ``c`` of the first encoder stage; deeper stages
            use ``2c``, ``4c`` and ``8c`` channels.
        mask_max: upper bound of the predicted mask.
    """

    def __init__(self, base_channels=32, mask_max=0.5):
        super().__init__()
        c = base_channels
        self.mask_max = mask_max
        # NOTE: attribute names below are the state_dict keys of saved
        # checkpoints; keep them stable.
        self.enc1 = ConvBlock(1, c)
        self.pool1 = nn.MaxPool2d(2, 2)
        self.enc2 = ConvBlock(c, c * 2)
        self.pool2 = nn.MaxPool2d(2, 2)
        self.enc3 = ConvBlock(c * 2, c * 4)
        self.pool3 = nn.MaxPool2d(2, 2)
        self.enc4 = ConvBlock(c * 4, c * 8)
        self.pool4 = nn.MaxPool2d(2, 2)
        self.bottleneck = nn.Sequential(
            GatedResidualBlock(c * 8, dilation=1),
            GatedResidualBlock(c * 8, dilation=2),
            GatedResidualBlock(c * 8, dilation=4),
        )
        self.up4 = nn.ConvTranspose2d(c * 8, c * 8, 2, stride=2)
        self.dec4 = ConvBlock(c * 16, c * 4)
        self.up3 = nn.ConvTranspose2d(c * 4, c * 4, 2, stride=2)
        self.dec3 = ConvBlock(c * 8, c * 2)
        self.up2 = nn.ConvTranspose2d(c * 2, c * 2, 2, stride=2)
        self.dec2 = ConvBlock(c * 4, c)
        self.up1 = nn.ConvTranspose2d(c, c, 2, stride=2)
        self.dec1 = ConvBlock(c * 2, c)
        self.mask_head = nn.Conv2d(c, 1, 1)

    def forward(self, x):
        """Predict a mask with the same F x T size as ``x``.

        Args:
            x: tensor of shape (B, 1, F, T); ``enc1`` expects one input channel.

        Returns:
            Mask of shape (B, 1, F, T) with values in ``[0, mask_max]``.
        """
        orig_f, orig_t = x.shape[2], x.shape[3]
        # Four 2x poolings: zero-pad F and T (at the end) up to multiples of
        # 16 so encoder/decoder maps line up; the mask is cropped back below.
        pad_f = (16 - orig_f % 16) % 16
        pad_t = (16 - orig_t % 16) % 16
        if pad_f > 0 or pad_t > 0:
            x = F.pad(x, (0, pad_t, 0, pad_f))
        e1 = self.enc1(x)
        e2 = self.enc2(self.pool1(e1))
        e3 = self.enc3(self.pool2(e2))
        e4 = self.enc4(self.pool3(e3))
        b = self.bottleneck(self.pool4(e4))
        d4 = self.dec4(self._skip_cat(self.up4(b), e4))
        d3 = self.dec3(self._skip_cat(self.up3(d4), e3))
        d2 = self.dec2(self._skip_cat(self.up2(d3), e2))
        d1 = self.dec1(self._skip_cat(self.up1(d2), e1))
        mask = torch.sigmoid(self.mask_head(d1)) * self.mask_max
        return mask[:, :, :orig_f, :orig_t]

    @staticmethod
    def _skip_cat(up, skip):
        """Resize ``up`` to ``skip``'s F x T, then concatenate on channels.

        Axes that are too short are zero-padded at the end, and axes that are
        too long are cropped. Each axis is handled independently, so one
        short axis and one long axis are both fixed.

        Returns:
            ``torch.cat([up_resized, skip], dim=1)``.
        """
        df = skip.shape[2] - up.shape[2]
        dt = skip.shape[3] - up.shape[3]
        if df > 0 or dt > 0:
            up = F.pad(up, (0, max(dt, 0), 0, max(df, 0)))
        # Always crop: a no-op when sizes already match, and it also covers
        # the mixed pad/crop case the old elif branch missed.
        up = up[:, :, :skip.shape[2], :skip.shape[3]]
        return torch.cat([up, skip], dim=1)
| # ============================================================ | |
| # ResidualCNNNch (7-channel forensic CNN) | |
| # ============================================================ | |
class ResidualCNNNch(nn.Module):
    """N-channel forensic CNN: three Conv-BN-ReLU stages plus an MLP head.

    The first two stages end in 2x2 max pooling, and the third ends in an
    adaptive average pool to 4x4. The head is Dropout(0.5) -> Linear(2048,
    256) -> ReLU -> Dropout(0.3) -> Linear(256, 1).

    Args:
        in_channels: number of input feature channels (default 7).
    """

    def __init__(self, in_channels=7):
        super().__init__()
        self.in_channels = in_channels
        widths = (in_channels, 32, 64, 128)
        layers = []
        for stage, (c_in, c_out) in enumerate(zip(widths, widths[1:])):
            is_last = stage == len(widths) - 2
            layers += [
                nn.Conv2d(c_in, c_out, 3, padding=1),
                nn.BatchNorm2d(c_out),
                nn.ReLU(inplace=True),
                nn.AdaptiveAvgPool2d((4, 4)) if is_last else nn.MaxPool2d(2, 2),
            ]
        # Layer indices inside ``features``/``classifier`` are state_dict keys.
        self.features = nn.Sequential(*layers)
        self.classifier = nn.Sequential(
            nn.Dropout(0.5),
            nn.Linear(128 * 4 * 4, 256),
            nn.ReLU(inplace=True),
            nn.Dropout(0.3),
            nn.Linear(256, 1),
        )

    def forward(self, x):
        """Map (B, in_channels, H, W) input to (B,) scores."""
        pooled = self.features(x)
        flat = pooled.reshape(pooled.size(0), -1)
        return self.classifier(flat).squeeze(-1)
class ResidualCNN7ch(nn.Module):
    """7-channel CNN for the v9.x pipeline.

    Four Conv-BN-ReLU-MaxPool stages (32/64/128/256 channels) feed global
    average pooling and then an FC head (256 -> 128 -> 1). It is deeper than
    ResidualCNNNch, which has 3 convs.
    Weights (per original note): models/cnn_v94_best.pt (v9.4 SOTA, balanced
    dataset).
    """

    def __init__(self):
        super().__init__()
        # Attribute names and order match saved checkpoints.
        self.conv1 = nn.Conv2d(7, 32, 3, padding=1)
        self.bn1 = nn.BatchNorm2d(32)
        self.pool1 = nn.MaxPool2d(2)
        self.conv2 = nn.Conv2d(32, 64, 3, padding=1)
        self.bn2 = nn.BatchNorm2d(64)
        self.pool2 = nn.MaxPool2d(2)
        self.conv3 = nn.Conv2d(64, 128, 3, padding=1)
        self.bn3 = nn.BatchNorm2d(128)
        self.pool3 = nn.MaxPool2d(2)
        self.conv4 = nn.Conv2d(128, 256, 3, padding=1)
        self.bn4 = nn.BatchNorm2d(256)
        self.pool4 = nn.MaxPool2d(2)
        self.global_pool = nn.AdaptiveAvgPool2d(1)
        self.fc1 = nn.Linear(256, 128)
        # NOTE(review): registered but never applied in forward(); eval-mode
        # output is unaffected, but training with this class skips it.
        self.dropout = nn.Dropout(0.5)
        self.fc2 = nn.Linear(128, 1)

    def forward(self, x):
        """x: (B, 7, N_MELS, T) -> (B,) logits."""
        stages = (
            (self.conv1, self.bn1, self.pool1),
            (self.conv2, self.bn2, self.pool2),
            (self.conv3, self.bn3, self.pool3),
            (self.conv4, self.bn4, self.pool4),
        )
        for conv, bn, pool in stages:
            x = pool(F.relu(bn(conv(x))))
        pooled = self.global_pool(x)
        hidden = F.relu(self.fc1(pooled.view(pooled.size(0), -1)))
        return self.fc2(hidden).view(-1)
| # ============================================================ | |
| # DifferentiableMel | |
| # ============================================================ | |
class DifferentiableMel(nn.Module):
    """STFT magnitude -> per-item standardized log-mel (dB), using torch ops only.

    Args:
        sr: sample rate in Hz; FFT bin frequencies and the mel range span
            0 .. sr / 2.
        n_fft: FFT size; inputs must have ``n_fft // 2 + 1`` frequency bins.
        n_mels: number of triangular mel filters.
        top_db: dynamic range kept below each item's peak, in dB.
    """

    def __init__(self, sr=44100, n_fft=2048, n_mels=128, top_db=80.0):
        super().__init__()
        n_freqs = n_fft // 2 + 1
        fb = self._create_mel_fb(n_freqs, n_mels, 0.0, sr / 2, sr)
        # Buffer: follows .to(device) and is stored in the state_dict.
        self.register_buffer('fb', fb)
        self.top_db = top_db

    @staticmethod
    def _create_mel_fb(n_freqs, n_mels, f_min, f_max, sr):
        """Return an (n_freqs, n_mels) float32 triangular mel filterbank.

        Uses the mel scale ``2595 * log10(1 + f / 700)``. Filter edges are
        spaced evenly in mel between ``f_min`` and ``f_max``, and each
        triangle peaks at 1.0 (no area normalization). FFT bin frequencies
        are ``linspace(0, sr / 2, n_freqs)``.

        Declared static because ``__init__`` calls it through ``self`` with
        five positional arguments.
        """
        def hz_to_mel(f):
            return 2595.0 * np.log10(1.0 + f / 700.0)

        def mel_to_hz(m):
            return 700.0 * (10.0 ** (m / 2595.0) - 1.0)

        mel_pts = np.linspace(hz_to_mel(f_min), hz_to_mel(f_max), n_mels + 2)
        hz_pts = mel_to_hz(mel_pts)
        freqs = np.linspace(0, sr / 2, n_freqs)[:, None]  # (n_freqs, 1)
        lo = hz_pts[:-2][None, :]    # (1, n_mels) left edges
        mid = hz_pts[1:-1][None, :]  # centres
        hi = hz_pts[2:][None, :]     # right edges
        rise_w = mid - lo
        fall_w = hi - mid
        # Zero-width sides produce inf/nan here; those entries are masked out
        # by the width > 0 conditions below, exactly as in the scalar loop.
        with np.errstate(divide='ignore', invalid='ignore'):
            rising = (freqs - lo) / rise_w
            falling = (hi - freqs) / fall_w
        on_rise = (lo <= freqs) & (freqs <= mid) & (rise_w > 0)
        on_fall = (mid < freqs) & (freqs <= hi) & (fall_w > 0)
        fb = np.where(on_rise, rising, np.where(on_fall, falling, 0.0))
        return torch.from_numpy(fb.astype(np.float32))

    def forward(self, stft_mag):
        """(B, 1, F, T) magnitude -> (B, 1, n_mels, T) normalized log-mel.

        The log-mel is in dB. It is clipped to ``top_db`` below each item's
        maximum, then standardized per item (zero mean, unit std over
        mel x time).
        """
        x = stft_mag.squeeze(1)
        power = x ** 2
        mel = torch.einsum('fm,bft->bmt', self.fb, power)
        # Floor before log10 so silent bins do not produce -inf.
        mel_db = 10.0 * torch.log10(torch.clamp(mel, min=1e-10))
        max_val = mel_db.amax(dim=(-2, -1), keepdim=True)
        mel_db = torch.clamp(mel_db, min=max_val - self.top_db)
        mean = mel_db.mean(dim=(-2, -1), keepdim=True)
        std = mel_db.std(dim=(-2, -1), keepdim=True)
        mel_norm = (mel_db - mean) / (std + 1e-9)
        return mel_norm.unsqueeze(1)
| # ============================================================ | |
| # CPU HPSS (librosa) | |
| # ============================================================ | |
def hpss_cpu(mag, kernel_size=31):
    """Harmonic/percussive separation with librosa on CPU.

    Each batch item is converted to NumPy, separated with
    ``librosa.decompose.hpss(..., kernel_size=kernel_size)``, and the results
    are stacked back into tensors on ``mag``'s original device.
    This is the demo CPU pipeline; per the original note, the GPU HPSS used
    for training lives in train_nch_cnn_020303.py.

    Args:
        mag: magnitude tensor of shape (B, 1, F, T).
        kernel_size: median-filter size passed to librosa (default 31, the
            previously hard-coded value).

    Returns:
        (H_mag, P_mag): tensors of shape (B, 1, F, T) on ``mag.device``.
    """
    import librosa  # optional dependency, needed only for this path

    device = mag.device
    # librosa HPSS is non-differentiable anyway; detaching lets tensors that
    # require grad be converted to NumPy instead of raising.
    mag_np = mag.detach().squeeze(1).cpu().numpy()  # (B, F, T)
    pairs = [librosa.decompose.hpss(spec, kernel_size=kernel_size) for spec in mag_np]
    H_mag = torch.from_numpy(np.stack([h for h, _ in pairs])).unsqueeze(1).to(device)
    P_mag = torch.from_numpy(np.stack([p for _, p in pairs])).unsqueeze(1).to(device)
    return H_mag, P_mag
| # ============================================================ | |
| # GPU/MPS HPSS (순수 PyTorch — unfold + median, Triton 불필요) | |
| # ============================================================ | |
def _gpu_median_filter_2d(x, kernel_size, dim):
    """Median filter along one axis of a (B, F, T) tensor via unfold + median.

    The tensor is reflect-padded by ``kernel_size // 2`` on the filtered axis,
    so the output has the same shape as ``x``. Per the original notes, this is
    fast on CUDA but very slow on MPS.

    Args:
        x: tensor of shape (B, F, T).
        kernel_size: odd window length.
        dim: 2 filters along time (harmonic estimate); any other value
            filters along frequency (percussive estimate).
    """
    half = kernel_size // 2
    along_time = dim == 2
    padding = (half, half) if along_time else (0, 0, half, half)
    axis = 2 if along_time else 1
    windows = F.pad(x, padding, mode='reflect').unfold(axis, kernel_size, 1)
    return torch.median(windows, dim=-1).values
def _avg_filter_2d(x, kernel_size, dim):
    """Moving-average smoothing along one axis; a fast stand-in for a median filter.

    Per the original notes, this is far faster than median on MPS. It is not
    identical to a median, and ``hpss_gpu_pure`` uses the median filter
    instead. Edge windows average only the in-range samples
    (``count_include_pad=False``), so the implicit zero padding no longer
    pulls the first and last frames toward zero.

    Args:
        x: tensor of shape (B, F, T).
        kernel_size: odd window length.
        dim: 2 smooths along time; any other value smooths along frequency.

    Returns:
        Tensor with the same shape as ``x``.
    """
    pad = kernel_size // 2
    B, F_dim, T = x.shape
    if dim == 2:  # time axis
        x_flat = x.reshape(B * F_dim, 1, T)
        out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1,
                           padding=pad, count_include_pad=False)
        return out.reshape(B, F_dim, T)
    # frequency axis: move F last so avg_pool1d runs along it
    x_t = x.transpose(1, 2)  # (B, T, F)
    x_flat = x_t.reshape(B * T, 1, F_dim)
    out = F.avg_pool1d(x_flat, kernel_size=kernel_size, stride=1,
                       padding=pad, count_include_pad=False)
    return out.reshape(B, T, F_dim).transpose(1, 2)
def hpss_gpu_pure(mag, h_kernel=31, p_kernel=31):
    """Median-filter HPSS in pure PyTorch with squared-ratio soft masks.

    The median filters run on ``mag``'s device, except for MPS inputs. Those
    are filtered on CPU, because unfold().median() is extremely slow on MPS
    (~13 s per track per the original note), and the results are moved back.
    The avg-pool approximation is deliberately not used: per the original
    note, all CNNs were trained on median-filter HPSS.

    Args:
        mag: STFT magnitude of shape (B, 1, F, T).
        h_kernel: median length along time (harmonic estimate).
        p_kernel: median length along frequency (percussive estimate).

    Returns:
        (H_mag, P_mag), each (B, 1, F, T). These are ``mag`` multiplied by
        H^2 / (H^2 + P^2 + 1e-10) and P^2 / (H^2 + P^2 + 1e-10), where H
        and P are the time- and frequency-filtered spectrograms.
    """
    spec = mag.squeeze(1)  # (B, F, T)
    on_mps = spec.device.type == 'mps'
    src = spec.cpu() if on_mps else spec
    harm = _gpu_median_filter_2d(src, h_kernel, dim=2)
    perc = _gpu_median_filter_2d(src, p_kernel, dim=1)
    if on_mps:
        harm = harm.to(spec.device)
        perc = perc.to(spec.device)
    harm_pow = harm ** 2
    perc_pow = perc ** 2
    total = harm_pow + perc_pow + 1e-10
    harmonic = (spec * (harm_pow / total)).unsqueeze(1)
    percussive = (spec * (perc_pow / total)).unsqueeze(1)
    return harmonic, percussive
| # ============================================================ | |
| # 7ch Forensic Feature Computation | |
| # ============================================================ | |
def compute_forensic_features_7ch(mel_res, mel_H, mel_P):
    """Stack seven forensic feature maps along the channel axis.

    Channels, in order:
        1. mel_res       - UNet residual mel spectrogram (as given)
        2. mel_H         - HPSS harmonic mel (as given)
        3. mel_P         - HPSS percussive mel (as given)
        4. delta         - first time difference of mel_res, zero at frame 0
        5. delta2        - first time difference of delta, zero at frame 0
        6. hp_ratio      - mel_H - mel_P (for log-scaled mels, a log H/P ratio)
        7. spectral_flux - |delta|

    Args:
        mel_res, mel_H, mel_P: tensors of shape (B, 1, N_MELS, T).

    Returns:
        Tensor of shape (B, 7, N_MELS, T).
    """
    def _time_delta(t):
        # Difference along the last (time) axis, left-padded with one zero
        # frame so the length stays T.
        return F.pad(torch.diff(t, n=1, dim=-1), (1, 0))

    first = _time_delta(mel_res)
    second = _time_delta(first)
    channels = (mel_res, mel_H, mel_P, first, second, mel_H - mel_P, first.abs())
    return torch.cat(channels, dim=1)