|
|
import os |
|
|
import glob |
|
|
import importlib |
|
|
import json |
|
|
|
|
|
import librosa |
|
|
import soundfile as sf |
|
|
import torch |
|
|
import torchaudio |
|
|
import math |
|
|
import torch.nn as nn |
|
|
|
|
|
|
|
|
class PositionalEncoding(nn.Module):
    """Absolute sinusoidal positional encoding, precomputed as a lookup table.

    The table is filled once at construction time following

        PE(pos, 2i)   = sin(pos / (10000^(2i / dmodel)))
        PE(pos, 2i+1) = cos(pos / (10000^(2i / dmodel)))

    and stored as the (non-trainable) buffer ``pe`` with shape
    ``(1, max_len, input_size)``.

    Arguments
    ---------
    input_size: int
        Embedding dimension. Must be even, otherwise ``ValueError`` is raised.
    max_len : int, optional
        Max length of the input sequences (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})")
        self.max_len = max_len

        # Column vector of positions 0..max_len-1, shape (max_len, 1).
        time_index = torch.arange(0, self.max_len).unsqueeze(1).float()
        # One inverse frequency per sin/cos pair, shape (input_size / 2,).
        inv_freq = torch.exp(torch.arange(0, input_size, 2).float() * -(math.log(10000.0) / input_size))

        table = torch.zeros(self.max_len, input_size, requires_grad=False)
        # Even channels carry the sine, odd channels the cosine.
        table[:, 0::2] = torch.sin(time_index * inv_freq)
        table[:, 1::2] = torch.cos(time_index * inv_freq)

        # Leading singleton dimension so the table broadcasts over the batch.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of ``x`` along dimension 1 is used; its values are not
        read and the encoding is NOT added to ``x``.

        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea)

        Returns
        -------
        tensor
            A detached copy of ``pe[:, :x.size(1)]``. Because this is a plain
            slice, at most ``max_len`` positions are returned.
        """
        n_steps = x.size(1)
        encodings = self.pe[:, :n_steps]
        return encodings.clone().detach()
|
|
|
|
|
|
|
|
def count_parameters(model):
    """
    Count the number of parameters in a PyTorch model.

    The total is printed in millions and also returned, so callers can use
    the value programmatically (previously the function only printed it and
    returned ``None`` despite documenting an ``int`` return).

    Parameters:
        model (torch.nn.Module): The PyTorch model. Every tensor yielded by
            ``model.parameters()`` is counted.

    Returns:
        int: Number of parameters in the model.
    """
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model params number {n_params/1e6} M")
    return n_params
|
|
|
|
|
|
|
|
def import_attr(import_path):
    """Resolve a dotted path such as ``"package.module.Name"`` to an object.

    The path is split at its LAST dot: everything before it is imported as a
    module, and the part after it is looked up as an attribute of that module.

    Parameters:
        import_path (str): Dotted path containing at least one ``.``.

    Returns:
        The attribute named by the final path component.
    """
    module_name, attr_name = import_path.rsplit(".", 1)
    owner = importlib.import_module(module_name)
    return getattr(owner, attr_name)
|
|
|
|
|
|
|
|
class Params:
    """Hyperparameter container backed by a json file.

    Every top-level key of the json object becomes an attribute of the
    instance.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        # Construction is the same operation as an update on an empty instance.
        self.update(json_path)

    def save(self, json_path):
        """Write all current attributes to `json_path` as indented json."""
        with open(json_path, "w") as out_file:
            json.dump(self.__dict__, out_file, indent=4)

    def update(self, json_path):
        """Load parameters from a json file, overwriting attributes of the same name."""
        with open(json_path) as in_file:
            loaded = json.load(in_file)
        self.__dict__.update(loaded)

    @property
    def dict(self):
        """Dict-like access to the instance, e.g. `params.dict['learning_rate']`.

        This is the instance's own attribute dictionary (not a copy), so
        changes made through it are visible on the instance.
        """
        return self.__dict__
|
|
|
|
|
|
|
|
def load_net_torch(expriment_config, return_params=False):
    """Instantiate the module described by a json experiment config.

    The config's ``pl_module_args`` are passed as keyword arguments to the
    class named by its ``pl_module`` dotted path, after forcing
    ``slow_model_ckpt`` and ``prev_ckpt`` to ``None`` and ``use_dp`` to
    ``False``.

    Parameters:
        expriment_config (str): Path to the json config file.
        return_params (bool): When true, also return the config contents.

    Returns:
        The instantiated module, or ``(module, params)`` when
        ``return_params`` is true. ``params`` is a plain dict freshly re-read
        from the file, so it does NOT contain the overrides applied above.
    """
    hparams = Params(expriment_config)

    # Overrides applied before construction.
    module_kwargs = hparams.pl_module_args
    module_kwargs["slow_model_ckpt"] = None
    module_kwargs["use_dp"] = False
    module_kwargs["prev_ckpt"] = None

    module_cls = import_attr(hparams.pl_module)
    pl_module = module_cls(**module_kwargs)

    # Re-read the file to hand back the untouched configuration as a dict.
    with open(expriment_config) as config_file:
        raw_config = json.load(config_file)

    return (pl_module, raw_config) if return_params else pl_module
|
|
|
|
|
|
|
|
def load_net(expriment_config, return_params=False):
    """Instantiate the module described by a json experiment config.

    The config's ``pl_module_args`` are passed as keyword arguments to the
    class named by its ``pl_module`` dotted path, after forcing ``use_dp`` to
    ``False``.

    Parameters:
        expriment_config (str): Path to the json config file.
        return_params (bool): When true, also return the config contents.

    Returns:
        The instantiated module, or ``(module, params)`` when
        ``return_params`` is true. ``params`` is a plain dict freshly re-read
        from the file, so it does NOT contain the ``use_dp`` override.
    """
    hparams = Params(expriment_config)

    module_kwargs = hparams.pl_module_args
    module_kwargs["use_dp"] = False

    module_cls = import_attr(hparams.pl_module)
    pl_module = module_cls(**module_kwargs)

    # Re-read the file to hand back the untouched configuration as a dict.
    with open(expriment_config) as config_file:
        raw_config = json.load(config_file)

    return (pl_module, raw_config) if return_params else pl_module
|
|
|
|
|
|
|
|
def load_pretrained(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build the module of a run directory and restore one of its checkpoints.

    The module is created from ``<run_dir>/config.json`` via ``load_net`` and
    its state is loaded from ``<run_dir>/checkpoints/best.pt`` (or
    ``last.pt`` when ``use_last`` is true).

    Parameters:
        run_dir (str): Directory containing ``config.json`` and ``checkpoints/``.
        return_params (bool): When true, also return the config returned by
            ``load_net``.
        map_location: Forwarded as the second argument of the module's
            ``load_state``.
        use_last (bool): Select ``last.pt`` instead of ``best.pt``.

    Returns:
        The module, or ``(module, params)`` when ``return_params`` is true.

    Raises:
        FileNotFoundError: If the selected checkpoint file does not exist.
    """
    pl_module, params = load_net(os.path.join(run_dir, "config.json"), return_params=True)

    name = "last.pt" if use_last else "best.pt"
    ckpt_path = os.path.join(run_dir, f"checkpoints/{name}")

    checkpoint_found = os.path.exists(ckpt_path)
    if not checkpoint_found:
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    return (pl_module, params) if return_params else pl_module
|
|
|
|
|
|
|
|
def load_pretrained_with_last(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build the module of a run directory and restore one of its checkpoints.

    This function was a statement-for-statement copy of ``load_pretrained``;
    it now delegates to it so the two entry points cannot drift apart. The
    name is kept for existing callers.

    Parameters:
        run_dir (str): Directory containing ``config.json`` and ``checkpoints/``.
        return_params (bool): When true, also return the loaded config.
        map_location: Forwarded to the module's ``load_state``.
        use_last (bool): Load ``checkpoints/last.pt`` instead of
            ``checkpoints/best.pt``.

    Returns:
        The module, or ``(module, params)`` when ``return_params`` is true.

    Raises:
        FileNotFoundError: If the selected checkpoint file does not exist.
    """
    return load_pretrained(
        run_dir,
        return_params=return_params,
        map_location=map_location,
        use_last=use_last,
    )
|
|
|
|
|
|
|
|
def load_pretrained2(run_dir, return_params=False, map_location="cpu"):
    """Build the module of a run directory and restore its ``best.pt`` checkpoint.

    The module is created from ``<run_dir>/config.json`` via ``load_net`` and
    its state is loaded from ``<run_dir>/checkpoints/best.pt``. Unlike the
    other loaders in this file, no existence check is done up front; a
    missing checkpoint surfaces as whatever ``load_state`` raises.

    Parameters:
        run_dir (str): Directory containing ``config.json`` and ``checkpoints/``.
        return_params (bool): When true, also return the config returned by
            ``load_net``.
        map_location: Forwarded as the second argument of the module's
            ``load_state``. It used to be accepted but silently ignored.

    Returns:
        The module, or ``(module, params)`` when ``return_params`` is true.
    """
    config_path = os.path.join(run_dir, "config.json")
    pl_module, params = load_net(config_path, return_params=True)

    ckpt_path = os.path.join(run_dir, "checkpoints", "best.pt")
    print("Loading checkpoint from", ckpt_path)

    # Honour the caller's map_location, matching the sibling loaders' calls.
    pl_module.load_state(ckpt_path, map_location)

    if return_params:
        return pl_module, params
    return pl_module
|
|
|
|
|
|
|
|
def load_torch_pretrained(run_dir, return_params=False, map_location="cpu", model_epoch="best"):
    """Build the module of a run directory via ``load_net_torch`` and restore a checkpoint.

    The module is created from ``<run_dir>/config.json`` and its state is
    loaded from ``<run_dir>/checkpoints/<model_epoch>.pt``. The config path,
    the checkpoint path and the restored epoch are printed along the way.

    Parameters:
        run_dir (str): Directory containing ``config.json`` and ``checkpoints/``.
        return_params (bool): When true, also return the config returned by
            ``load_net_torch``.
        map_location: Forwarded as the second argument of the module's
            ``load_state``.
        model_epoch: Checkpoint file stem (default ``"best"``).

    Returns:
        The module, or ``(module, params)`` when ``return_params`` is true.

    Raises:
        FileNotFoundError: If the selected checkpoint file does not exist.
    """
    config_path = os.path.join(run_dir, "config.json")
    print(config_path)
    pl_module, params = load_net_torch(config_path, return_params=True)

    ckpt_path = os.path.join(run_dir, f"checkpoints/{model_epoch}.pt")
    checkpoint_found = os.path.exists(ckpt_path)
    if not checkpoint_found:
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    return (pl_module, params) if return_params else pl_module
|
|
|
|
|
|
|
|
def read_audio_file(file_path, sr):
    """
    Load an audio file with librosa and return only its first result element.

    The file is loaded with ``mono=False`` and the requested rate ``sr``;
    the remaining elements of librosa's return value are discarded.

    @param file_path: Path of the file to read
    @param sr: Sampling rate passed to librosa
    """
    loaded = librosa.core.load(file_path, mono=False, sr=sr)
    return loaded[0]
|
|
|
|
|
|
|
|
def read_audio_file_torch(file_path, downsample=1, input_mean=False):
    """
    Load an audio file with torchaudio, optionally resampling and reducing channels.

    @param file_path: Path of the file to read
    @param downsample: When greater than 1, the waveform is resampled from
        its native rate to ``native_rate // downsample``
    @param input_mean: Channel handling, applied only when the loaded
        waveform has more than one entry along dimension 0:
        ``True`` averages over dimension 0, ``"L"`` keeps index 0,
        ``"R"`` keeps index 1. Any other value leaves the waveform untouched.
    @return: The waveform tensor; after any reduction above, dimension 0 has
        size 1.
    """
    waveform, sample_rate = torchaudio.load(file_path)
    if downsample > 1:
        target_rate = sample_rate // downsample
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    # Nothing to reduce when there is at most one entry along dimension 0.
    if waveform.shape[0] <= 1:
        return waveform

    # `== True` (not `is True`) is intentional: it keeps accepting any value
    # that compares equal to True, exactly as before.
    if input_mean == True:  # noqa: E712
        return torch.mean(waveform, dim=0, keepdim=True)
    if input_mean == "L":
        return waveform[0:1, ...]
    if input_mean == "R":
        return waveform[1:2, ...]
    return waveform
|
|
|
|
|
|
|
|
def write_audio_file(file_path, data, sr, subtype="PCM_16"):
    """
    Writes an audio signal to a file using soundfile.

    @param file_path: Path of the file to write to
    @param data: Audio signal to write (n_channels x n_samples)
    @param sr: Sampling rate
    @param subtype: Subtype passed to soundfile (default "PCM_16")
    """
    # The signal is transposed before being handed to soundfile.
    transposed = data.T
    sf.write(file_path, transposed, sr, subtype)
|
|
|
|
|
|
|
|
def read_json(path):
    """Parse the json file at `path` and return the decoded object.

    The file is opened in binary mode and the raw bytes are handed to the
    json decoder.
    """
    with open(path, "rb") as handle:
        payload = handle.read()
    return json.loads(payload)
|
|
|
|
|
|
|
|
import random |
|
|
import numpy as np |
|
|
|
|
|
|
|
|
def seed_all(seed):
    """Seed Python's `random`, NumPy and PyTorch with the same value.

    The seeds are applied in that order; when CUDA is available,
    `torch.cuda.manual_seed` is called with the same seed as well.

    Parameters:
        seed (int): Seed value passed to every generator.
    """
    for seed_fn in (random.seed, np.random.seed, torch.manual_seed):
        seed_fn(seed)

    if not torch.cuda.is_available():
        return
    torch.cuda.manual_seed(seed)
|
|
|