Spaces:
Sleeping
Sleeping
import tempfile
from pathlib import Path
from typing import Tuple

import numpy as np
import soundfile as sf
import torch

from model import convNet
from preprocess import Audio, fft_and_melscale
from synthesize import create_tja, detect, synthesize
def trim_silence(
    data: np.ndarray,
    sr: int,
    start_threshold: float = 0.2,
    end_threshold: float = 0.1,
    pad_seconds: int = 3,
) -> np.ndarray:
    """Trim leading and trailing near-silence from a mono signal.

    Scans forward for the first sample whose magnitude reaches
    ``start_threshold`` and backward for the last sample whose magnitude
    reaches ``end_threshold``, then keeps ``pad_seconds`` of context on
    each side of the detected region.

    Args:
        data: 1-D array of audio samples.
        sr: Sample rate in samples per second.
        start_threshold: Magnitude below which leading samples count as silence.
        end_threshold: Magnitude below which trailing samples count as silence.
        pad_seconds: Seconds of padding retained around the detected region.

    Returns:
        The trimmed (possibly empty) slice of ``data``.
    """
    start = 0
    end = len(data) - 1
    while start < len(data) and np.abs(data[start]) < start_threshold:
        start += 1
    while end > 0 and np.abs(data[end]) < end_threshold:
        end -= 1
    start = max(start - sr * pad_seconds, 0)
    end = min(end + sr * pad_seconds, len(data))
    # Report durations actually removed from each side. (The previous message
    # printed end/sr for the tail, which is the kept-region end position,
    # not the trimmed duration.)
    print(
        f"Trimming {start / sr} seconds from the start and "
        f"{(len(data) - end) / sr} seconds from the end"
    )
    return data[start:end]
class ODCNN:
    """Onset-detection wrapper around a pair of ``convNet`` models.

    One network scores "don" onsets and the other "ka" onsets
    (presumably the two Taiko drum hit types — confirm against the model
    training code); :meth:`run` converts an audio file into per-difficulty
    onset timestamps and a TJA chart.
    """

    def __init__(self, don_model: str, ka_model: str, device: torch.device = "cpu"):
        """Load both onset networks.

        Args:
            don_model: Path to the don-network state-dict file.
            ka_model: Path to the ka-network state-dict file.
            device: Torch device (or device string) the networks run on.
        """
        self.donNet = self._load_net(don_model, device)
        self.kaNet = self._load_net(ka_model, device)
        self.device = device

    @staticmethod
    def _load_net(state_dict_path: str, device) -> convNet:
        """Build a ``convNet`` on ``device`` and load weights from ``state_dict_path``."""
        net = convNet().to(device)
        # Weights are deserialized onto CPU; load_state_dict then copies
        # them into the (possibly GPU-resident) parameters.
        net.load_state_dict(torch.load(state_dict_path, map_location="cpu"))
        return net

    def run(self, file: str, delta=0.05, trim=True) -> Tuple[str, str]:
        """Chart an audio file.

        Args:
            file: Path to the input audio file.
            delta: Detection threshold for the hardest ("ura") difficulty.
            trim: Whether to trim leading/trailing silence before analysis.

        Returns:
            A tuple ``(synthesized_path, tja)``: the path of a synthesized
            preview audio file and the TJA chart produced by ``create_tja``.
        """
        data, sr = sf.read(file, always_2d=True)
        song = Audio(data, sr)
        # Downmix to mono by averaging channels.
        song.data = song.data.mean(axis=1)
        if trim:
            song.data = trim_silence(song.data, sr)
        song.feats = fft_and_melscale(
            song,
            nhop=512,
            nffts=[1024, 2048, 4096],
            mel_nband=80,
            mel_freqlo=27.5,
            mel_freqhi=16000.0,
        )
        don_inference = np.reshape(
            self.donNet.infer(song.feats, self.device, minibatch=4192), (-1)
        )
        ka_inference = np.reshape(
            self.kaNet.infer(song.feats, self.device, minibatch=4192), (-1)
        )
        # One detection pass per difficulty; smaller delta appears to mean a
        # more sensitive threshold (harder, denser chart).
        easy_detection = detect(don_inference, ka_inference, delta=0.25)
        normal_detection = detect(don_inference, ka_inference, delta=0.2)
        hard_detection = detect(don_inference, ka_inference, delta=0.15)
        oni_detection = detect(don_inference, ka_inference, delta=0.075)
        ura_detection = detect(don_inference, ka_inference, delta=delta)
        # Close the handle immediately so the descriptor is not leaked and the
        # path is writable by synthesize on platforms (e.g. Windows) where an
        # open NamedTemporaryFile cannot be reopened.
        with tempfile.NamedTemporaryFile(suffix=".mp3", delete=False) as tmp:
            synthesized_path = tmp.name
        synthesize(*hard_detection, song, synthesized_path)
        file = Path(file)
        tja = create_tja(
            song,
            timestamps=[
                easy_detection,
                normal_detection,
                hard_detection,
                oni_detection,
                ura_detection,
            ],
            title=file.stem,
            wave=file.name,
        )
        return synthesized_path, tja