Yusen committed on
Commit
22224e1
·
1 Parent(s): 68e03b2

Upload commons.py

Browse files
Files changed (1) hide show
  1. commons.py +188 -0
commons.py ADDED
@@ -0,0 +1,188 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ import numpy as np
3
+ import torch
4
+ from torch import nn
5
+ from torch.nn import functional as F
6
+
7
def slice_pitch_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window out of every row of a 2-D tensor.

    Args:
        x: tensor indexed as ``x[row, position]`` (presumably a
            [batch, time] pitch track -- confirm against the caller).
        ids_str: per-row start positions; ``ids_str[i]`` is the first
            position kept for row ``i``.
        segment_size: number of positions kept per row.

    Returns:
        A tensor shaped like ``x[:, :segment_size]`` whose row ``i`` holds
        ``x[i, ids_str[i]:ids_str[i] + segment_size]``.
    """
    segments = torch.zeros_like(x[:, :segment_size])
    for row in range(x.size(0)):
        start = ids_str[row]
        segments[row] = x[row, start:start + segment_size]
    return segments
14
+
15
def rand_slice_segments_with_pitch(x, pitch, x_lengths=None, segment_size=4):
    """Draw random start positions and slice both ``x`` and ``pitch`` there.

    Args:
        x: 3-D tensor; its sizes are unpacked as (batch, _, time).
        pitch: tensor sliced with ``slice_pitch_segments`` using the same
            start positions as ``x``.
        x_lengths: usable length per batch item; when ``None`` the full
            size of the last axis of ``x`` is used for every item.
        segment_size: length of each extracted window.

    Returns:
        Tuple ``(x_segments, pitch_segments, ids_str)`` where ``ids_str``
        holds the start position that was drawn for each batch item.
    """
    batch, _, frames = x.size()
    limit = frames if x_lengths is None else x_lengths
    # Start positions are drawn uniformly from [0, limit - segment_size].
    max_start = limit - segment_size + 1
    uniform = torch.rand([batch]).to(device=x.device)
    ids_str = (uniform * max_start).to(dtype=torch.long)
    x_segments = slice_segments(x, ids_str, segment_size)
    pitch_segments = slice_pitch_segments(pitch, ids_str, segment_size)
    return x_segments, pitch_segments, ids_str
24
+
25
def init_weights(m, mean=0.0, std=0.01):
    """Re-initialise the weight of a module whose class name contains "Conv".

    The weight is filled in place with samples from a normal distribution
    with the given ``mean`` and ``std``. Modules whose class name does not
    contain "Conv" are left untouched.
    """
    if "Conv" in m.__class__.__name__:
        m.weight.data.normal_(mean, std)
29
+
30
+
31
def get_padding(kernel_size, dilation=1):
    """Return ``(kernel_size * dilation - dilation) / 2`` truncated to int."""
    span = kernel_size * dilation - dilation
    return int(span / 2)
33
+
34
+
35
def convert_pad_shape(pad_shape):
    """Flatten a list of per-axis pad pairs, last axis first.

    ``[[a0, a1], [b0, b1]]`` becomes ``[b0, b1, a0, a1]``; this reversed,
    flat layout is the form passed to ``F.pad`` elsewhere in this file.
    """
    flat = []
    for pair in pad_shape[::-1]:
        flat.extend(pair)
    return flat
39
+
40
+
41
def intersperse(lst, item):
    """Return a list with ``item`` before, between and after the elements of ``lst``.

    The result has ``2 * len(lst) + 1`` entries: ``item`` at every even
    index and the elements of ``lst`` at the odd indices.
    """
    out = [item] * (2 * len(lst) + 1)
    for pos, element in enumerate(lst):
        out[2 * pos + 1] = element
    return out
45
+
46
+
47
def kl_divergence(m_p, logs_p, m_q, logs_q):
    """KL(P||Q)

    Element-wise value of
    ``(logs_q - logs_p) - 0.5
      + 0.5 * (exp(2 * logs_p) + (m_p - m_q) ** 2) * exp(-2 * logs_q)``
    (presumably ``m_*`` are means and ``logs_*`` log standard deviations of
    diagonal Gaussians -- confirm against the caller).
    """
    spread = torch.exp(2. * logs_p) + ((m_p - m_q) ** 2)
    inv_var_q = torch.exp(-2. * logs_q)
    kl = (logs_q - logs_p) - 0.5
    kl += 0.5 * spread * inv_var_q
    return kl
52
+
53
+
54
def rand_gumbel(shape):
    """Sample from the Gumbel distribution, protect from overflows."""
    # Rescale the uniform draw away from 0 and 1 so neither log hits log(0).
    uniform = torch.rand(shape) * 0.99998 + 0.00001
    log_uniform = torch.log(uniform)
    return -torch.log(-log_uniform)
58
+
59
+
60
def rand_gumbel_like(x):
    """Return Gumbel noise with the size, dtype and device of ``x``."""
    noise = rand_gumbel(x.size())
    return noise.to(dtype=x.dtype, device=x.device)
63
+
64
+
65
def slice_segments(x, ids_str, segment_size=4):
    """Cut one fixed-length window along the last axis of every batch item.

    Args:
        x: tensor indexed as ``x[item, :, position]`` (presumably
            [batch, channels, time] -- confirm against the caller).
        ids_str: per-item start positions along the last axis.
        segment_size: number of positions kept per item.

    Returns:
        A tensor shaped like ``x[:, :, :segment_size]`` whose item ``i``
        holds ``x[i, :, ids_str[i]:ids_str[i] + segment_size]``.
    """
    segments = torch.zeros_like(x[:, :, :segment_size])
    for item in range(x.size(0)):
        start = ids_str[item]
        segments[item] = x[item, :, start:start + segment_size]
    return segments
72
+
73
+
74
def rand_slice_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly positioned window out of every batch item of ``x``.

    Args:
        x: 3-D tensor; its sizes are unpacked as (batch, _, time).
        x_lengths: usable length per batch item; when ``None`` the full
            size of the last axis of ``x`` is used for every item.
        segment_size: length of each extracted window.

    Returns:
        Tuple ``(segments, ids_str)`` where ``ids_str`` holds the start
        position that was drawn for each batch item.
    """
    batch, _, frames = x.size()
    limit = x_lengths if x_lengths is not None else frames
    # Start positions are drawn uniformly from [0, limit - segment_size].
    max_start = limit - segment_size + 1
    draws = torch.rand([batch]).to(device=x.device)
    ids_str = (draws * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
82
+
83
+
84
def rand_spec_segments(x, x_lengths=None, segment_size=4):
    """Slice a randomly positioned window out of every batch item of ``x``.

    Args:
        x: 3-D tensor; its sizes are unpacked as (batch, _, time).
        x_lengths: usable length per batch item; when ``None`` the full
            size of the last axis of ``x`` is used for every item.
        segment_size: length of each extracted window.

    Returns:
        Tuple ``(segments, ids_str)`` where ``ids_str`` holds the start
        position that was drawn for each batch item.
    """
    batch, _, frames = x.size()
    limit = x_lengths if x_lengths is not None else frames
    # NOTE(review): unlike rand_slice_segments there is no "+ 1" here, so
    # the last possible start (limit - segment_size) appears never to be
    # drawn -- confirm whether that is intentional.
    max_start = limit - segment_size
    draws = torch.rand([batch]).to(device=x.device)
    ids_str = (draws * max_start).to(dtype=torch.long)
    return slice_segments(x, ids_str, segment_size), ids_str
92
+
93
+
94
def get_timing_signal_1d(
        length, channels, min_timescale=1.0, max_timescale=1.0e4):
    """Build a sinusoidal position signal of shape ``[1, channels, length]``.

    ``channels // 2`` timescales are spaced geometrically from
    ``min_timescale`` to ``max_timescale``. The first ``channels // 2`` rows
    hold ``sin(position / timescale)``, the next ``channels // 2`` rows hold
    the matching ``cos`` values, and an odd ``channels`` gets one trailing
    row of zeros.

    Args:
        length: number of positions (size of the last axis).
        channels: number of rows in the returned signal.
        min_timescale: smallest timescale.
        max_timescale: largest timescale.

    Returns:
        A float tensor of shape ``[1, channels, length]`` on the default
        device.
    """
    position = torch.arange(length, dtype=torch.float)
    num_timescales = channels // 2
    # With a single timescale (channels of 2 or 3) ``num_timescales - 1`` is
    # zero and the division used to raise ZeroDivisionError. Clamping the
    # denominator to 1 leaves every case with two or more timescales
    # unchanged and makes the single timescale equal to ``min_timescale``.
    log_timescale_increment = (
        math.log(float(max_timescale) / float(min_timescale)) /
        max(num_timescales - 1, 1))
    inv_timescales = min_timescale * torch.exp(
        torch.arange(num_timescales, dtype=torch.float) * -log_timescale_increment)
    scaled_time = position.unsqueeze(0) * inv_timescales.unsqueeze(1)
    signal = torch.cat([torch.sin(scaled_time), torch.cos(scaled_time)], 0)
    # Pad one zero row at the bottom when ``channels`` is odd.
    signal = F.pad(signal, [0, 0, 0, channels % 2])
    signal = signal.view(1, channels, length)
    return signal
108
+
109
+
110
def add_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4):
    """Return ``x`` plus the timing signal from ``get_timing_signal_1d``.

    ``x`` must be 3-D; its second and third sizes are passed on as
    ``channels`` and ``length``. The signal is cast to the dtype and device
    of ``x`` before the addition.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return x + timing
114
+
115
+
116
def cat_timing_signal_1d(x, min_timescale=1.0, max_timescale=1.0e4, axis=1):
    """Concatenate the timing signal from ``get_timing_signal_1d`` onto ``x``.

    ``x`` must be 3-D; its second and third sizes are passed on as
    ``channels`` and ``length``. The signal is cast to the dtype and device
    of ``x`` and joined to it along ``axis``.
    """
    _, channels, length = x.size()
    timing = get_timing_signal_1d(length, channels, min_timescale, max_timescale)
    timing = timing.to(dtype=x.dtype, device=x.device)
    return torch.cat([x, timing], axis)
120
+
121
+
122
def subsequent_mask(length):
    """Return a ``[1, 1, length, length]`` lower-triangular mask of ones.

    Entry ``[0, 0, i, j]`` is 1 when ``j <= i`` and 0 otherwise.
    """
    lower = torch.tril(torch.ones(length, length))
    return lower[None, None]
125
+
126
+
127
@torch.jit.script
def fused_add_tanh_sigmoid_multiply(input_a, input_b, n_channels):
    """Gated activation: tanh of one channel slice times sigmoid of the rest.

    The two inputs are summed; axis 1 of the sum is split at
    ``n_channels[0]``. The first part goes through tanh, the remainder
    through sigmoid, and the two are multiplied element-wise.
    """
    split = n_channels[0]
    summed = input_a + input_b
    filt = torch.tanh(summed[:, :split, :])
    gate = torch.sigmoid(summed[:, split:, :])
    return filt * gate
135
+
136
+
137
def convert_pad_shape(pad_shape):
    """Flatten a list of per-axis pad pairs, last axis first.

    ``[[a0, a1], [b0, b1]]`` becomes ``[b0, b1, a0, a1]``.
    """
    return [amount for pair in pad_shape[::-1] for amount in pair]
141
+
142
+
143
def shift_1d(x):
    """Shift a 3-D tensor one step to the right along its last axis.

    A zero is inserted at the front of the last axis and the final
    position is dropped, so the output has the same shape as ``x``.
    """
    # F.pad takes pads last-axis-first: (left, right) for the last axis,
    # then the two untouched leading axes.
    padded = F.pad(x, [1, 0, 0, 0, 0, 0])
    return padded[:, :, :-1]
146
+
147
+
148
def sequence_mask(length, max_length=None):
    """Build a boolean mask that is True at positions below each length.

    Args:
        length: 1-D tensor of lengths.
        max_length: number of mask columns; defaults to ``length.max()``.

    Returns:
        Boolean tensor of shape ``[len(length), max_length]`` where entry
        ``[i, j]`` is ``j < length[i]``.
    """
    width = length.max() if max_length is None else max_length
    positions = torch.arange(width, dtype=length.dtype, device=length.device)
    return positions[None, :] < length[:, None]
153
+
154
+
155
def generate_path(duration, mask):
    """
    duration: [b, 1, t_x]
    mask: [b, 1, t_y, t_x]

    Returns a tensor shaped like ``mask``. For every x position, the
    y positions between the previous cumulative duration and its own
    cumulative duration are set to 1, and the result is multiplied by
    ``mask``.
    """
    b, _, t_y, t_x = mask.shape
    # End position (exclusive) of every x step along the y axis.
    ends = torch.cumsum(duration, -1).view(b * t_x)
    covered = sequence_mask(ends, t_y).to(mask.dtype).view(b, t_x, t_y)
    # Subtract the coverage of the previous x step so that only the span
    # belonging to the current step stays set.
    previous = F.pad(covered, convert_pad_shape([[0, 0], [1, 0], [0, 0]]))[:, :-1]
    path = covered - previous
    return path.unsqueeze(1).transpose(2, 3) * mask
171
+
172
+
173
def clip_grad_value_(parameters, clip_value, norm_type=2):
    """Clamp gradients element-wise and report their overall norm.

    Args:
        parameters: a single tensor or an iterable of tensors; entries
            whose ``grad`` is ``None`` are skipped.
        clip_value: gradients are clamped in place to
            ``[-clip_value, clip_value]``; ``None`` disables clamping so
            only the norm is computed.
        norm_type: order of the norm; ``float("inf")`` selects the
            maximum absolute gradient entry.

    Returns:
        The total gradient norm as a Python float. Each parameter's norm
        is measured before that parameter's gradient is clamped.
    """
    if isinstance(parameters, torch.Tensor):
        parameters = [parameters]
    parameters = [p for p in parameters if p.grad is not None]
    norm_type = float(norm_type)
    if clip_value is not None:
        clip_value = float(clip_value)
    # The generic formula (sum of norm ** p) ** (1 / p) degenerates for
    # p = inf (x ** inf, then total ** 0 == 1.0), so the infinity norm is
    # accumulated as a running maximum instead.
    use_max = norm_type == float("inf")

    total_norm = 0.0
    for p in parameters:
        param_norm = p.grad.data.norm(norm_type).item()
        if use_max:
            total_norm = max(total_norm, param_norm)
        else:
            total_norm += param_norm ** norm_type
        if clip_value is not None:
            p.grad.data.clamp_(min=-clip_value, max=clip_value)
    if not use_max:
        total_norm = total_norm ** (1. / norm_type)
    return total_norm