import math
import torch
import torch.nn.functional as F
import json


def init_weights(m):
    """Xavier-uniform initialise Conv1d / ConvTranspose1d / Linear weights and zero their biases.

    Can be passed to ``torch.nn.Module.apply``. Modules of any other type are
    left untouched.

    Args:
        m: A ``torch.nn.Module``.
    """
    # All three layer types share the same scheme, so one tuple isinstance
    # check replaces three identical branches.
    if isinstance(m, (torch.nn.Conv1d, torch.nn.ConvTranspose1d, torch.nn.Linear)):
        gain = torch.nn.init.calculate_gain('relu')
        torch.nn.init.xavier_uniform_(m.weight, gain=gain)
        if m.bias is not None:
            torch.nn.init.zeros_(m.bias)


def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL(P || Q) between two diagonal Gaussians.

    P = N(m_p, exp(logs_p)^2) and Q = N(m_q, exp(logs_q)^2), i.e. ``logs_*``
    are log standard deviations:

        KL = log(s_q / s_p) + (s_p^2 + (m_p - m_q)^2) / (2 * s_q^2) - 1/2

    Args:
        m_p, logs_p: Mean and log-std of P.
        m_q, logs_q: Mean and log-std of Q (broadcastable with P's tensors).

    Returns:
        Tensor of non-negative per-element divergences.
    """
    # The log-ratio term is NOT scaled by 1/2 (log s_q - log s_p is already
    # half of the log-variance ratio).
    kl = (logs_q - logs_p) - 0.5
    kl = kl + 0.5 * (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * torch.exp(-2.0 * logs_q)
    return kl


def rand_gumbel(shape, device="cpu"):
    """Sample from the standard Gumbel distribution via -log(-log(U)), U ~ Uniform(0, 1).

    U is clamped to [1e-5, 1 - 1e-5] so neither logarithm hits 0 or infinity.

    Args:
        shape: Output shape.
        device: Device for the samples (default ``"cpu"``, as before).
    """
    u = torch.rand(shape, device=device).clamp(1e-5, 1 - 1e-5)
    # torch.log (not math.log) so this works element-wise on whole tensors.
    return -torch.log(-torch.log(u))


def rand_uniform(shape, device="cpu"):
    """Sample from Uniform[0, 1).

    Args:
        shape: Output shape.
        device: Device for the samples (default ``"cpu"``, as before).
    """
    return torch.rand(shape, device=device)


def rand_logistic(shape, device=None):
    """Sample from the standard logistic distribution (location 0, scale 1).

    Uses the inverse CDF log(U) - log(1 - U), U ~ Uniform(0, 1), with U clamped
    to [1e-5, 1 - 1e-5] to keep the result finite.

    Args:
        shape: Output shape.
        device: Device for the samples (default: torch's default device).
    """
    u = torch.rand(shape, device=device).clamp(1e-5, 1 - 1e-5)
    return torch.log(u) - torch.log1p(-u)


def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window per batch item along the last axis.

    Args:
        x: Tensor indexed as ``x[i, :, start:end]`` — assumes (batch, channels, time).
        ids_str: Per-item START OFFSETS along the last axis (one per batch item),
            e.g. as produced by ``rand_slice_segments``.
        segment_size: Window length.

    Returns:
        Tensor stacking the windows: (len(ids_str), channels, segment_size).

    Raises:
        ValueError: If a window would fall outside ``x``'s last axis.
    """
    length = x.size(-1)
    segments = []
    for i, start in enumerate(ids_str):
        start = int(start)
        end = start + segment_size
        # Out-of-range windows would silently be short/empty and then make
        # torch.stack fail with an opaque size-mismatch error.
        if start < 0 or end > length:
            raise ValueError(
                f"segment [{start}, {end}) out of range for length {length} (item {i})"
            )
        segments.append(x[i, :, start:end])
    return torch.stack(segments)


def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a random window of ``segment_size`` frames from each batch item.

    Args:
        x: Tensor of shape (batch, channels, time).
        x_lengths: Optional per-item valid lengths (tensor or sequence). When
            given, windows are drawn from the valid region ``[0, x_lengths[i])``
            where possible; otherwise from the whole time axis.
        segment_size: Window length.

    Returns:
        Tensor of shape (batch, channels, segment_size).

    Raises:
        ValueError: If ``segment_size`` exceeds the time dimension of ``x``.
    """
    b, _, t = x.size()
    if t < segment_size:
        raise ValueError(
            f"segment_size ({segment_size}) exceeds the time dimension ({t})"
        )
    if x_lengths is None:
        # randint's upper bound is exclusive; +1 makes start t - segment_size reachable.
        ids_str = torch.randint(0, t - segment_size + 1, (b,), device=x.device)
    else:
        lengths = torch.as_tensor(x_lengths, device=x.device)
        # floor(U * (L - S + 1)) picks a start in {0, ..., L - S}.
        ids_str = (torch.rand(b, device=x.device) * (lengths - segment_size + 1)).long()
        # Items shorter than segment_size give a negative range -> start at 0;
        # the upper clamp guarantees the window always fits inside x.
        ids_str = ids_str.clamp(0, t - segment_size)
    return slice_segments(x, ids_str, segment_size)


def get_hparams_from_file(config_path):
    """Read a JSON config file (e.g. config.json) and return the parsed object (typically a dict).

    Args:
        config_path: Path to a UTF-8 encoded JSON file.
    """
    with open(config_path, "r", encoding="utf-8") as f:
        return json.load(f)