# Uploaded by buchi-stdesign ("Upload 18 files"), commit 1ee91f8 (verified).
import math
import torch
import torch.nn.functional as F
import json
def init_weights(m):
    """Initialise a single module's parameters in place.

    Conv1d, ConvTranspose1d and Linear modules get Xavier-uniform weights
    scaled by the ReLU gain, and their bias (if present) is zeroed. Any other
    module type is left untouched. The one-module signature suggests this is
    meant to be passed to ``nn.Module.apply`` (confirm with callers).

    Args:
        m: the module to initialise.
    """
    handled_types = (torch.nn.Conv1d, torch.nn.ConvTranspose1d, torch.nn.Linear)
    if not isinstance(m, handled_types):
        return
    relu_gain = torch.nn.init.calculate_gain('relu')
    torch.nn.init.xavier_uniform_(m.weight, gain=relu_gain)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Compute the element-wise KL divergence KL(P || Q) between two Gaussians.

    P = N(m_p, exp(logs_p)**2) and Q = N(m_q, exp(logs_q)**2). ``logs_*`` are
    treated as log standard deviations, since ``exp(2 * logs)`` is used as the
    variance. All arguments are tensors that broadcast against each other.

    Returns:
        Tensor of per-element KL values (no reduction is applied).
    """
    # Closed form: log(sigma_q / sigma_p)
    #              + (sigma_p^2 + (m_p - m_q)^2) / (2 sigma_q^2) - 1/2.
    # In log-std parameterisation the log-ratio term has coefficient 1;
    # halving it is only correct for log-variance inputs.
    kl = (logs_q - logs_p) - 0.5
    kl = kl + 0.5 * (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * torch.exp(-2.0 * logs_q)
    return kl
def rand_gumbel(shape):
    """Sample a CPU tensor of the given shape from the standard Gumbel(0, 1) distribution.

    Uses inverse-CDF sampling: ``-log(-log(u))`` with ``u ~ U(0, 1)``.
    """
    # Clamp away from 0 and 1 so both logarithms stay finite.
    u = torch.rand(shape, device="cpu").clamp(1e-5, 1 - 1e-5)
    # torch.log works element-wise; math.log only accepts Python scalars and
    # fails on tensors with more than one element.
    return -torch.log(-torch.log(u))
def rand_uniform(shape):
    """Return a CPU tensor of the given shape filled with samples from U[0, 1)."""
    target_device = "cpu"
    samples = torch.rand(shape, device=target_device)
    return samples
def rand_logistic(shape):
    """Sample a tensor of the given shape from the standard Logistic(0, 1) distribution.

    Uses inverse-CDF sampling: ``log(u) - log(1 - u)`` with ``u ~ U(0, 1)``.
    """
    # RelaxedOneHotCategorical returns Gumbel-softmax simplex vectors, not
    # logistic samples. Use the logistic quantile function instead.
    # The clamp keeps both logarithms finite.
    u = torch.rand(shape).clamp(1e-5, 1 - 1e-5)
    return torch.log(u) - torch.log1p(-u)
def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window out of each batch item of ``x``.

    Args:
        x: 3-D tensor indexed as ``x[i, :, start:end]``; the last axis is the
            one being sliced (presumably time - confirm with callers).
        ids_str: per-item start offsets along the last axis, as a 1-D tensor
            or any sequence of ints, one entry per batch item.
        segment_size: window length.

    Returns:
        Tensor stacking ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``
        for every ``i``.
    """
    segments = []
    for i, start in enumerate(ids_str):
        # ids_str holds frame offsets, so it is used directly as the start
        # index. Scaling it by segment_size would jump past the end of x.
        start = int(start)
        segments.append(x[i, :, start:start + segment_size])
    return torch.stack(segments)


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Cut a randomly placed window of ``segment_size`` frames from each batch item.

    Args:
        x: tensor of size ``(b, d, t)``; the window is taken along the last axis.
        x_lengths: optional valid length per item, presumably a 1-D tensor of
            shape ``(b,)`` (TODO confirm). When ``None``, every item is
            treated as having length ``t``.
        segment_size: window length.

    Returns:
        Tensor of stacked windows, one per batch item.
    """
    b, _, t = x.size()
    if x_lengths is None:
        x_lengths = t
    # Starts are drawn from [0, length - segment_size] inclusive, hence the
    # +1. The bound is floored at 1 so that lengths shorter than the window
    # start at 0 instead of at a negative index.
    ids_str_max = x_lengths - segment_size + 1
    if torch.is_tensor(ids_str_max):
        ids_str_max = ids_str_max.to(x.device).clamp(min=1)
    else:
        ids_str_max = max(ids_str_max, 1)
    ids_str = (torch.rand(b, device=x.device) * ids_str_max).long()
    return slice_segments(x, ids_str, segment_size)
def get_hparams_from_file(config_path):
    """Read the UTF-8 JSON file at ``config_path`` and return the parsed value.

    A JSON object at the top level comes back as a plain ``dict``.
    """
    with open(config_path, "r", encoding="utf-8") as config_file:
        return json.load(config_file)