guilinhu's picture
Upload folder using huggingface_hub
df9f13e verified
import os
import glob
import importlib
import json
import librosa
import soundfile as sf
import torch
import torchaudio
import math
import torch.nn as nn
class PositionalEncoding(nn.Module):
    """Fixed (non-learned) sinusoidal positional encoding table.

    The table is filled once at construction time following
    PE(pos, 2i)   = sin(pos / 10000^(2i / dmodel))
    PE(pos, 2i+1) = cos(pos / 10000^(2i / dmodel))
    and stored as the buffer ``pe`` with shape (1, max_len, input_size).

    Arguments
    ---------
    input_size: int
        Embedding dimension; must be even.
    max_len : int, optional
        Number of positions stored in the table (default 2500).

    Example
    -------
    >>> a = torch.rand((8, 120, 512))
    >>> enc = PositionalEncoding(input_size=a.shape[-1])
    >>> b = enc(a)
    >>> b.shape
    torch.Size([1, 120, 512])
    """

    def __init__(self, input_size, max_len=2500):
        super().__init__()
        if input_size % 2 != 0:
            raise ValueError(f"Cannot use sin/cos positional encoding with odd channels (got channels={input_size})")
        self.max_len = max_len

        # Column vector of positions 0..max_len-1, shape (max_len, 1).
        pos_column = torch.arange(0, self.max_len).unsqueeze(1).float()
        # One inverse frequency per (sin, cos) channel pair.
        log_step = -(math.log(10000.0) / input_size)
        inv_freq = torch.exp(torch.arange(0, input_size, 2).float() * log_step)
        # Broadcast to (max_len, input_size / 2); shared by the sin and cos halves.
        phase = pos_column * inv_freq

        table = torch.zeros(self.max_len, input_size, requires_grad=False)
        table[:, 0::2] = torch.sin(phase)  # even channels
        table[:, 1::2] = torch.cos(phase)  # odd channels

        # Leading axis of size 1 so the table broadcasts over the batch.
        self.register_buffer("pe", table.unsqueeze(0))

    def forward(self, x):
        """Return the encodings for the first ``x.size(1)`` positions.

        Only the size of dimension 1 of ``x`` is used; the values of ``x`` are
        not read and ``x`` is not added to the result.

        Arguments
        ---------
        x : tensor
            Input feature shape (batch, time, fea)

        Returns
        -------
        tensor
            Detached copy of the table slice, shape (1, x.size(1), input_size).
        """
        n_steps = x.size(1)
        return self.pe[:, :n_steps].clone().detach()
def count_parameters(model):
    """
    Count the number of parameters in a PyTorch model.

    The total is printed in millions and also returned, so callers can use
    the value programmatically.

    Parameters:
        model (torch.nn.Module): The PyTorch model.

    Returns:
        int: Number of parameters in the model.
    """
    n_params = sum(p.numel() for p in model.parameters())
    print(f"Model params number {n_params/1e6} M")
    # The docstring has always promised an int; previously nothing was returned.
    return n_params
def import_attr(import_path):
    """Resolve a dotted path such as ``"package.module.Name"`` to the object it names.

    The text before the last dot is imported as a module and the text after
    it is looked up as an attribute of that module.
    """
    module_name, attr_name = import_path.rsplit(".", 1)
    owner = importlib.import_module(module_name)
    return getattr(owner, attr_name)
class Params:
    """Hyperparameter container populated from a json file.

    Every top-level key of the json object becomes an instance attribute.

    Example:
    ```
    params = Params(json_path)
    print(params.learning_rate)
    params.learning_rate = 0.5  # change the value of learning_rate in params
    ```
    """

    def __init__(self, json_path):
        with open(json_path) as src:
            loaded = json.load(src)
        vars(self).update(loaded)

    def save(self, json_path):
        """Write the current attributes to `json_path` as json indented by 4."""
        with open(json_path, "w") as dst:
            json.dump(vars(self), dst, indent=4)

    def update(self, json_path):
        """Merge parameters from a json file into this instance.

        Existing attributes with the same name are overwritten; others are kept.
        """
        with open(json_path) as src:
            loaded = json.load(src)
        vars(self).update(loaded)

    @property
    def dict(self):
        """Gives dict-like access to Params instance by `params.dict['learning_rate']`"""
        return vars(self)
def load_net_torch(expriment_config, return_params=False):
    """Instantiate the module described by an experiment config file.

    The config is read through `Params`; before construction the entries
    "slow_model_ckpt", "use_dp" and "prev_ckpt" of `pl_module_args` are forced
    to None, False and None respectively. The class named by the config's
    `pl_module` dotted path is then built with those keyword arguments.

    When `return_params` is true the result is `(module, config)`, where
    `config` is the json file re-read from disk as a plain dict (so it does
    not contain the overrides applied above). Otherwise only the module is
    returned.
    """
    hparams = Params(expriment_config)
    module_kwargs = hparams.pl_module_args
    overrides = (("slow_model_ckpt", None), ("use_dp", False), ("prev_ckpt", None))
    for key, forced_value in overrides:
        module_kwargs[key] = forced_value

    module_cls = import_attr(hparams.pl_module)
    pl_module = module_cls(**module_kwargs)

    # Hand back the untouched on-disk config rather than the Params object.
    with open(expriment_config) as cfg_file:
        raw_config = json.load(cfg_file)

    return (pl_module, raw_config) if return_params else pl_module
def load_net(expriment_config, return_params=False):
    """Instantiate the module described by an experiment config file.

    The config is read through `Params`, `pl_module_args["use_dp"]` is forced
    to False, and the class named by the config's `pl_module` dotted path is
    built with the resulting keyword arguments.

    When `return_params` is true the result is `(module, config)`, where
    `config` is the json file re-read from disk as a plain dict (without the
    "use_dp" override). Otherwise only the module is returned.
    """
    hparams = Params(expriment_config)
    hparams.pl_module_args["use_dp"] = False

    build = import_attr(hparams.pl_module)
    pl_module = build(**hparams.pl_module_args)

    # Re-read the file so the caller receives the raw dict, not a Params object.
    with open(expriment_config) as cfg_file:
        raw_config = json.load(cfg_file)

    if not return_params:
        return pl_module
    return pl_module, raw_config
def load_pretrained(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build the module configured in `run_dir` and restore a saved checkpoint.

    The config is read from `<run_dir>/config.json`. The checkpoint is
    `<run_dir>/checkpoints/last.pt` when `use_last` is true, otherwise
    `<run_dir>/checkpoints/best.pt`. It is restored by calling
    `pl_module.load_state(ckpt_path, map_location)`.

    Returns `(module, config_dict)` when `return_params` is true, otherwise
    just the module.

    Raises FileNotFoundError if the selected checkpoint file does not exist.
    """
    pl_module, params = load_net(os.path.join(run_dir, "config.json"), return_params=True)

    # Select which checkpoint file to restore.
    ckpt_name = "last.pt" if use_last else "best.pt"
    ckpt_path = os.path.join(run_dir, f"checkpoints/{ckpt_name}")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    return (pl_module, params) if return_params else pl_module
def load_pretrained_with_last(run_dir, return_params=False, map_location="cpu", use_last=False):
    """Build the module configured in `run_dir` and restore "best.pt" or "last.pt".

    `<run_dir>/config.json` supplies the module definition. `use_last`
    chooses between `checkpoints/last.pt` (true) and `checkpoints/best.pt`
    (false); the chosen file is passed to `pl_module.load_state` together
    with `map_location`.

    Returns `(module, config_dict)` when `return_params` is true, otherwise
    just the module.

    Raises FileNotFoundError if the chosen checkpoint file is missing.
    """
    config_path = os.path.join(run_dir, "config.json")
    pl_module, params = load_net(config_path, return_params=True)

    # Map the flag to a checkpoint file name.
    checkpoint_file = {True: "last.pt", False: "best.pt"}[bool(use_last)]
    ckpt_path = os.path.join(run_dir, f"checkpoints/{checkpoint_file}")

    found = os.path.exists(ckpt_path)
    if not found:
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    if not return_params:
        return pl_module
    return pl_module, params
def load_pretrained2(run_dir, return_params=False, map_location="cpu"):
    """Build the module configured in `run_dir` and restore its "best" checkpoint.

    The config is read from `<run_dir>/config.json` and the weights from
    `<run_dir>/checkpoints/best.pt`. No existence check is made here, so a
    missing checkpoint surfaces as whatever error `pl_module.load_state`
    raises.

    `map_location` is forwarded to `pl_module.load_state`, as the other
    loaders in this module do.

    Returns `(module, config_dict)` when `return_params` is true, otherwise
    just the module.
    """
    config_path = os.path.join(run_dir, "config.json")
    pl_module, params = load_net(config_path, return_params=True)

    ckpt_path = os.path.join(run_dir, "checkpoints", "best.pt")
    print("Loading checkpoint from", ckpt_path)

    # Previously `map_location` was accepted but silently dropped; pass it on
    # so the caller's device choice is honoured.
    pl_module.load_state(ckpt_path, map_location)

    if return_params:
        return pl_module, params
    return pl_module
def load_torch_pretrained(run_dir, return_params=False, map_location="cpu", model_epoch="best"):
    """Build the module via `load_net_torch` and restore a named checkpoint.

    The config is read from `<run_dir>/config.json` (its path is printed) and
    the weights from `<run_dir>/checkpoints/<model_epoch>.pt`, restored with
    `pl_module.load_state(ckpt_path, map_location)`.

    Returns `(module, config_dict)` when `return_params` is true, otherwise
    just the module.

    Raises FileNotFoundError if the checkpoint file does not exist.
    """
    config_path = os.path.join(run_dir, "config.json")
    print(config_path)
    pl_module, params = load_net_torch(config_path, return_params=True)

    # `model_epoch` is the checkpoint file stem, e.g. "best".
    ckpt_path = os.path.join(run_dir, f"checkpoints/{model_epoch}.pt")
    if not os.path.exists(ckpt_path):
        raise FileNotFoundError(f"Given run ({run_dir}) doesn't have any pretrained checkpoints!")

    print("Loading checkpoint from", ckpt_path)
    pl_module.load_state(ckpt_path, map_location)
    print("Loaded module at epoch", pl_module.epoch)

    return (pl_module, params) if return_params else pl_module
def read_audio_file(file_path, sr):
    """
    Load an audio file with librosa and return only the sample array.

    `sr` is passed to librosa as the sampling rate and `mono=False` is
    requested; the sampling rate librosa returns is discarded.
    """
    samples, _ = librosa.core.load(file_path, mono=False, sr=sr)
    return samples
def read_audio_file_torch(file_path, downsample=1, input_mean=False):
    """
    Load an audio file with torchaudio, optionally resampling and reducing channels.

    @param file_path: Path of the file to read
    @param downsample: When greater than 1, the waveform is resampled to
        `sample_rate // downsample`
    @param input_mean: Channel reduction, applied only when the loaded
        waveform has more than one row along dimension 0:
        True averages over dimension 0, "L" keeps row 0, "R" keeps row 1;
        any other value leaves the waveform unchanged
    @return: The waveform tensor; reduced results keep a leading dimension of size 1
    """
    waveform, sample_rate = torchaudio.load(file_path)

    if downsample > 1:
        target_rate = sample_rate // downsample
        waveform = torchaudio.functional.resample(waveform, sample_rate, target_rate)

    # Nothing to reduce when there is at most one row along dimension 0.
    if not waveform.shape[0] > 1:
        return waveform

    # `==` (not `is`) is deliberate: it keeps values equal to True behaving the same.
    if input_mean == True:  # noqa: E712
        return torch.mean(waveform, dim=0).unsqueeze(0)
    if input_mean == "L":
        return waveform[0:1, ...]
    if input_mean == "R":
        return waveform[1:2, ...]
    return waveform
def write_audio_file(file_path, data, sr, subtype="PCM_16"):
    """
    Writes an audio signal to a file using soundfile.
    @param file_path: Path of the file to write to
    @param data: Audio signal to write (n_channels x n_samples)
    @param sr: Sampling rate
    @param subtype: Subtype passed through to `sf.write` (default "PCM_16")
    """
    # The signal is stored channels-first here; hand soundfile its transpose.
    transposed = data.T
    sf.write(file_path, transposed, sr, subtype)
def read_json(path):
    """Parse the json file at `path` and return the decoded object.

    The file is opened in binary mode and handed to `json.load`.
    """
    with open(path, "rb") as handle:
        payload = json.load(handle)
    return payload
import random
import numpy as np
def seed_all(seed):
    """Seed the Python, NumPy and PyTorch random number generators with `seed`.

    `torch.cuda.manual_seed` is additionally called when CUDA is available.
    """
    # Same order as before: stdlib random, then NumPy, then torch.
    for apply_seed in (random.seed, np.random.seed, torch.manual_seed):
        apply_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)