Spaces:
Runtime error
Runtime error
| import math | |
| import torch | |
| import torch.nn.functional as F | |
| import json | |
def init_weights(m):
    """Re-initialise a single layer in place; can be passed to ``torch.nn.Module.apply``.

    ``Conv1d``, ``ConvTranspose1d`` and ``Linear`` layers receive Xavier-uniform
    weights scaled by the ReLU gain, and their bias (when present) is set to
    zero. Any other module is left untouched.

    Args:
        m: the module to initialise.
    """
    handled_types = (torch.nn.Conv1d, torch.nn.ConvTranspose1d, torch.nn.Linear)
    if not isinstance(m, handled_types):
        return

    relu_gain = torch.nn.init.calculate_gain('relu')
    torch.nn.init.xavier_uniform_(m.weight, gain=relu_gain)
    if m.bias is not None:
        torch.nn.init.zeros_(m.bias)
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """Element-wise KL divergence KL(P || Q) between two diagonal Gaussians.

    P = N(m_p, exp(logs_p)**2) and Q = N(m_q, exp(logs_q)**2), i.e. ``logs_*``
    are log standard deviations. Arguments broadcast against each other and
    the result is returned un-reduced (same broadcast shape).

    Returns:
        ``(logs_q - logs_p) - 1/2
        + (exp(2 logs_p) + (m_p - m_q)**2) / (2 exp(2 logs_q))``
    """
    # Only the variance/mean-difference term carries the factor 1/2;
    # the log-ratio of standard deviations must not be halved.
    kl = (logs_q - logs_p) - 0.5
    kl = kl + 0.5 * (torch.exp(2.0 * logs_p) + (m_p - m_q) ** 2) * torch.exp(-2.0 * logs_q)
    return kl
def rand_gumbel(shape, device="cpu"):
    """Sample i.i.d. standard Gumbel(0, 1) noise.

    Uses inverse-CDF sampling ``-log(-log(u))`` with ``u ~ U(0, 1)``; ``u`` is
    clamped to ``[1e-5, 1 - 1e-5]`` so both logarithms stay finite.

    Args:
        shape: output shape, in any form accepted by ``torch.rand``.
        device: device to create the samples on (defaults to ``"cpu"``).

    Returns:
        A floating-point tensor of ``shape``.
    """
    u = torch.rand(shape, device=device).clamp(1e-5, 1 - 1e-5)
    # torch.log (element-wise), not math.log: math.log only accepts Python
    # scalars / single-element tensors and raises on anything larger.
    return -torch.log(-torch.log(u))
def rand_uniform(shape):
    """Draw samples from the continuous uniform distribution on [0, 1).

    Args:
        shape: output shape, forwarded unchanged to ``torch.rand``.

    Returns:
        A tensor of ``shape`` allocated on the CPU.
    """
    cpu = torch.device("cpu")
    return torch.rand(shape, device=cpu)
def rand_logistic(shape, device=None):
    """Sample i.i.d. standard Logistic(0, 1) noise.

    Uses inverse-CDF sampling ``log(u) - log(1 - u)`` with ``u ~ U(0, 1)``;
    ``u`` is clamped to ``[1e-5, 1 - 1e-5]`` so the result stays finite.

    Args:
        shape: output shape, in any form accepted by ``torch.rand``.
        device: device to create the samples on; ``None`` uses torch's default
            device.

    Returns:
        A floating-point tensor of ``shape``.
    """
    # Previously this drew from RelaxedOneHotCategorical, whose samples are
    # non-negative vectors summing to 1 along the last dim - not logistic noise.
    u = torch.rand(shape, device=device).clamp(1e-5, 1 - 1e-5)
    return torch.log(u) - torch.log1p(-u)
def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window per batch item out of ``x``.

    Args:
        x: tensor indexed as ``x[i, :, t]`` (presumably ``(batch, channels,
            time)`` - TODO confirm with callers).
        ids_str: per-item start offsets along the last axis, one per batch item
            to slice. They are used directly as offsets, not multiplied by
            ``segment_size``.
        segment_size: window length along the last axis.

    Returns:
        Tensor of shape ``(len(ids_str), x.size(1), segment_size)`` whose row
        ``i`` is ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``. Windows must
        fit inside ``x``; otherwise ``torch.stack`` fails on mismatched lengths.
    """
    segments = []
    for i, start in enumerate(ids_str):
        start = int(start)
        segments.append(x[i, :, start:start + segment_size])
    if not segments:
        # torch.stack rejects an empty list; return an empty batch instead.
        return x.new_empty((0, x.size(1), segment_size))
    return torch.stack(segments)
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Pick one random window start per batch item and slice ``x`` there.

    Args:
        x: 3-D tensor unpacked as ``(b, d, t)`` via ``x.size()``.
        x_lengths: optional per-item valid lengths along the last axis
            (presumably a length-``b`` tensor on ``x.device`` - TODO confirm
            with callers). When ``None``, every item is treated as length ``t``.
        segment_size: window length along the last axis.

    Returns:
        The result of ``slice_segments(x, ids_str, segment_size)`` for the
        drawn per-item start offsets ``ids_str``.
    """
    b, _, t = x.size()
    if x_lengths is None:
        x_lengths = t
    # Valid starts are 0 .. length - segment_size inclusive, hence the +1
    # (the old bounds never drew the last valid start, and randint(0, 0)
    # raised when t == segment_size).
    ids_str_max = x_lengths - segment_size + 1
    # clamp keeps items shorter than segment_size from getting a negative start.
    ids_str = (torch.rand(b, device=x.device) * ids_str_max).long().clamp(min=0)
    return slice_segments(x, ids_str, segment_size)
def get_hparams_from_file(config_path):
    """Load a JSON configuration file (such as ``config.json``).

    Args:
        config_path: path to a UTF-8 encoded JSON file.

    Returns:
        The parsed JSON value (a ``dict`` when the file holds a JSON object).
    """
    with open(config_path, "r", encoding="utf-8") as handle:
        return json.load(handle)