del fixed embedding voice diffussion
Browse files- Modules/diffusion/modules.py +30 -72
- Modules/diffusion/sampler.py +5 -47
- api.py +3 -3
- models.py +1 -0
- msinference.py +0 -2
Modules/diffusion/modules.py
CHANGED
|
@@ -146,6 +146,7 @@ class StyleTransformer1d(nn.Module):
|
|
| 146 |
return mapping
|
| 147 |
|
| 148 |
def run(self, x, time, embedding, features):
|
|
|
|
| 149 |
|
| 150 |
mapping = self.get_mapping(time, features)
|
| 151 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
|
@@ -161,37 +162,22 @@ class StyleTransformer1d(nn.Module):
|
|
| 161 |
|
| 162 |
return x
|
| 163 |
|
| 164 |
-
def forward(self,
|
| 165 |
-
|
| 166 |
-
|
| 167 |
embedding= None,
|
| 168 |
-
features = None
|
| 169 |
-
embedding_scale: float = 1.0) -> Tensor:
|
| 170 |
-
|
| 171 |
-
b, device = embedding.shape[0], embedding.device
|
| 172 |
-
fixed_embedding = self.fixed_embedding(embedding)
|
| 173 |
-
if embedding_mask_proba > 0.0:
|
| 174 |
-
# Randomly mask embedding
|
| 175 |
-
batch_mask = rand_bool(
|
| 176 |
-
shape=(b, 1, 1), proba=embedding_mask_proba, device=device
|
| 177 |
-
)
|
| 178 |
-
embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
| 179 |
|
| 180 |
-
|
| 181 |
-
|
| 182 |
-
|
| 183 |
-
|
| 184 |
-
|
| 185 |
-
|
| 186 |
-
|
| 187 |
-
|
| 188 |
-
|
| 189 |
-
|
| 190 |
-
# raise ValueError
|
| 191 |
-
return self.run(x, time, embedding=embedding, features=features)
|
| 192 |
-
|
| 193 |
-
|
| 194 |
-
return x
|
| 195 |
|
| 196 |
|
| 197 |
class StyleTransformerBlock(nn.Module):
|
|
@@ -216,24 +202,11 @@ class StyleTransformerBlock(nn.Module):
|
|
| 216 |
features=features,
|
| 217 |
style_dim=style_dim,
|
| 218 |
num_heads=num_heads,
|
| 219 |
-
head_features=head_features
|
| 220 |
-
use_rel_pos=use_rel_pos,
|
| 221 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
| 222 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
| 223 |
)
|
| 224 |
|
| 225 |
if self.use_cross_attention:
|
| 226 |
raise ValueError
|
| 227 |
-
# self.cross_attention = StyleAttention(
|
| 228 |
-
# features=features,
|
| 229 |
-
# style_dim=style_dim,
|
| 230 |
-
# num_heads=num_heads,
|
| 231 |
-
# head_features=head_features,
|
| 232 |
-
# context_features=context_features,
|
| 233 |
-
# use_rel_pos=use_rel_pos,
|
| 234 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
| 235 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
| 236 |
-
# )
|
| 237 |
|
| 238 |
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
| 239 |
|
|
@@ -254,7 +227,7 @@ class StyleAttention(nn.Module):
|
|
| 254 |
head_features: int,
|
| 255 |
num_heads: int,
|
| 256 |
context_features = None,
|
| 257 |
-
use_rel_pos: bool,
|
| 258 |
# rel_pos_num_buckets: Optional[int] = None,
|
| 259 |
# rel_pos_max_distance: Optional[int] = None,
|
| 260 |
):
|
|
@@ -274,23 +247,20 @@ class StyleAttention(nn.Module):
|
|
| 274 |
self.attention = AttentionBase(
|
| 275 |
features,
|
| 276 |
num_heads=num_heads,
|
| 277 |
-
head_features=head_features
|
| 278 |
-
use_rel_pos=use_rel_pos,
|
| 279 |
-
# rel_pos_num_buckets=rel_pos_num_buckets,
|
| 280 |
-
# rel_pos_max_distance=rel_pos_max_distance,
|
| 281 |
)
|
| 282 |
|
| 283 |
-
def forward(self, x
|
| 284 |
|
| 285 |
-
|
| 286 |
-
|
| 287 |
context = default(context, x)
|
| 288 |
-
|
| 289 |
-
|
| 290 |
x, context = self.norm(x, s), self.norm_context(context, s)
|
| 291 |
|
| 292 |
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
| 293 |
-
|
| 294 |
return self.attention(q, k, v)
|
| 295 |
|
| 296 |
|
|
@@ -310,25 +280,13 @@ class AttentionBase(nn.Module):
|
|
| 310 |
features,
|
| 311 |
*,
|
| 312 |
head_features,
|
| 313 |
-
num_heads
|
| 314 |
-
use_rel_pos,
|
| 315 |
-
out_features = None,
|
| 316 |
-
# rel_pos_num_buckets: Optional[int] = None,
|
| 317 |
-
# rel_pos_max_distance: Optional[int] = None,
|
| 318 |
-
):
|
| 319 |
super().__init__()
|
| 320 |
self.scale = head_features ** -0.5
|
| 321 |
self.num_heads = num_heads
|
| 322 |
-
|
| 323 |
-
|
| 324 |
-
|
| 325 |
-
if use_rel_pos:
|
| 326 |
-
raise ValueError
|
| 327 |
-
|
| 328 |
-
if out_features is None:
|
| 329 |
-
out_features = features
|
| 330 |
-
|
| 331 |
-
self.to_out = nn.Linear(in_features=mid_features, out_features=out_features)
|
| 332 |
|
| 333 |
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 334 |
# Split heads
|
|
@@ -358,7 +316,7 @@ class Attention(nn.Module):
|
|
| 358 |
num_heads,
|
| 359 |
out_features=None,
|
| 360 |
context_features=None,
|
| 361 |
-
use_rel_pos,
|
| 362 |
# rel_pos_num_buckets: Optional[int] = None,
|
| 363 |
# rel_pos_max_distance: Optional[int] = None,
|
| 364 |
):
|
|
@@ -381,7 +339,7 @@ class Attention(nn.Module):
|
|
| 381 |
out_features=out_features,
|
| 382 |
num_heads=num_heads,
|
| 383 |
head_features=head_features,
|
| 384 |
-
use_rel_pos=use_rel_pos,
|
| 385 |
# rel_pos_num_buckets=rel_pos_num_buckets,
|
| 386 |
# rel_pos_max_distance=rel_pos_max_distance,
|
| 387 |
)
|
|
|
|
| 146 |
return mapping
|
| 147 |
|
| 148 |
def run(self, x, time, embedding, features):
|
| 149 |
+
# called by forward()
|
| 150 |
|
| 151 |
mapping = self.get_mapping(time, features)
|
| 152 |
x = torch.cat([x.expand(-1, embedding.size(1), -1), embedding], axis=-1)
|
|
|
|
| 162 |
|
| 163 |
return x
|
| 164 |
|
| 165 |
+
def forward(self,
|
| 166 |
+
x,
|
| 167 |
+
time,
|
| 168 |
embedding= None,
|
| 169 |
+
features = None):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 170 |
|
| 171 |
+
b, device = embedding.shape[0], embedding.device
|
| 172 |
+
# if
|
| 173 |
+
# embedding_mask_proba: float = 0.0, > 0
|
| 174 |
+
# fixed_embedding = self.fixed_embedding(embedding)
|
| 175 |
+
# embedding = torch.where(batch_mask, fixed_embedding, embedding)
|
| 176 |
+
return self.run(x,
|
| 177 |
+
time,
|
| 178 |
+
embedding=embedding,
|
| 179 |
+
# embedding=self.fixed_embedding(embedding), # fixedemb has noisy beginnings on chapters.wav
|
| 180 |
+
features=features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 181 |
|
| 182 |
|
| 183 |
class StyleTransformerBlock(nn.Module):
|
|
|
|
| 202 |
features=features,
|
| 203 |
style_dim=style_dim,
|
| 204 |
num_heads=num_heads,
|
| 205 |
+
head_features=head_features
|
|
|
|
|
|
|
|
|
|
| 206 |
)
|
| 207 |
|
| 208 |
if self.use_cross_attention:
|
| 209 |
raise ValueError
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 210 |
|
| 211 |
self.feed_forward = FeedForward(features=features, multiplier=multiplier)
|
| 212 |
|
|
|
|
| 227 |
head_features: int,
|
| 228 |
num_heads: int,
|
| 229 |
context_features = None,
|
| 230 |
+
# use_rel_pos: bool,
|
| 231 |
# rel_pos_num_buckets: Optional[int] = None,
|
| 232 |
# rel_pos_max_distance: Optional[int] = None,
|
| 233 |
):
|
|
|
|
| 247 |
self.attention = AttentionBase(
|
| 248 |
features,
|
| 249 |
num_heads=num_heads,
|
| 250 |
+
head_features=head_features
|
|
|
|
|
|
|
|
|
|
| 251 |
)
|
| 252 |
|
| 253 |
+
def forward(self, x, s, *, context = None):
|
| 254 |
|
| 255 |
+
if context is not None:
|
| 256 |
+
raise ValueError
|
| 257 |
context = default(context, x)
|
| 258 |
+
|
| 259 |
+
|
| 260 |
x, context = self.norm(x, s), self.norm_context(context, s)
|
| 261 |
|
| 262 |
q, k, v = (self.to_q(x), *torch.chunk(self.to_kv(context), chunks=2, dim=-1))
|
| 263 |
+
|
| 264 |
return self.attention(q, k, v)
|
| 265 |
|
| 266 |
|
|
|
|
| 280 |
features,
|
| 281 |
*,
|
| 282 |
head_features,
|
| 283 |
+
num_heads):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 284 |
super().__init__()
|
| 285 |
self.scale = head_features ** -0.5
|
| 286 |
self.num_heads = num_heads
|
| 287 |
+
mid_features = head_features * num_heads
|
| 288 |
+
self.to_out = nn.Linear(in_features=mid_features,
|
| 289 |
+
out_features=features)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 290 |
|
| 291 |
def forward(self, q: Tensor, k: Tensor, v: Tensor) -> Tensor:
|
| 292 |
# Split heads
|
|
|
|
| 316 |
num_heads,
|
| 317 |
out_features=None,
|
| 318 |
context_features=None,
|
| 319 |
+
# use_rel_pos,
|
| 320 |
# rel_pos_num_buckets: Optional[int] = None,
|
| 321 |
# rel_pos_max_distance: Optional[int] = None,
|
| 322 |
):
|
|
|
|
| 339 |
out_features=out_features,
|
| 340 |
num_heads=num_heads,
|
| 341 |
head_features=head_features,
|
| 342 |
+
# use_rel_pos=use_rel_pos,
|
| 343 |
# rel_pos_num_buckets=rel_pos_num_buckets,
|
| 344 |
# rel_pos_max_distance=rel_pos_max_distance,
|
| 345 |
)
|
Modules/diffusion/sampler.py
CHANGED
|
@@ -1,61 +1,18 @@
|
|
| 1 |
-
from math import
|
| 2 |
import torch.nn as nn
|
| 3 |
from einops import rearrange
|
| 4 |
from torch import Tensor
|
| 5 |
-
|
| 6 |
from functools import reduce
|
| 7 |
-
from inspect import isfunction
|
| 8 |
-
from math import ceil, floor, log2, pi
|
| 9 |
-
# from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
|
| 10 |
import torch
|
| 11 |
import torch.nn.functional as F
|
| 12 |
-
from einops import rearrange
|
| 13 |
-
from torch import Generator, Tensor
|
| 14 |
-
|
| 15 |
-
|
| 16 |
-
|
| 17 |
-
|
| 18 |
-
|
| 19 |
-
|
| 20 |
-
# def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
|
| 21 |
-
# return isinstance(obj, list) or isinstance(obj, tuple)
|
| 22 |
-
|
| 23 |
|
| 24 |
def default(val, d):
|
| 25 |
if val is not None: #exists(val):
|
| 26 |
return val
|
| 27 |
return d #d() if isfunction(d) else d
|
| 28 |
|
| 29 |
-
|
| 30 |
-
# def to_list(val: Union[T, Sequence[T]]) -> List[T]:
|
| 31 |
-
# if isinstance(val, tuple):
|
| 32 |
-
# return list(val)
|
| 33 |
-
# if isinstance(val, list):
|
| 34 |
-
# return val
|
| 35 |
-
# return [val] # type: ignore
|
| 36 |
-
|
| 37 |
-
|
| 38 |
-
# def prod(vals: Sequence[int]) -> int:
|
| 39 |
-
# return reduce(lambda x, y: x * y, vals)
|
| 40 |
-
|
| 41 |
-
|
| 42 |
-
def closest_power_2(x: float) -> int:
|
| 43 |
-
exponent = log2(x)
|
| 44 |
-
distance_fn = lambda z: abs(x - 2 ** z) # noqa
|
| 45 |
-
exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
|
| 46 |
-
return 2 ** int(exponent_closest)
|
| 47 |
-
|
| 48 |
-
def rand_bool(shape, proba, device = None):
|
| 49 |
-
if proba == 1:
|
| 50 |
-
return torch.ones(shape, device=device, dtype=torch.bool)
|
| 51 |
-
elif proba == 0:
|
| 52 |
-
return torch.zeros(shape, device=device, dtype=torch.bool)
|
| 53 |
-
else:
|
| 54 |
-
return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
|
| 55 |
-
|
| 56 |
-
# ============================= END functions from diffusION.utils
|
| 57 |
-
|
| 58 |
-
|
| 59 |
class LogNormalDistribution():
|
| 60 |
def __init__(self, mean: float, std: float):
|
| 61 |
self.mean = mean
|
|
@@ -238,7 +195,8 @@ class DiffusionSampler(nn.Module):
|
|
| 238 |
|
| 239 |
# Compute sigmas using schedule
|
| 240 |
sigmas = self.sigma_schedule(num_steps, device)
|
| 241 |
-
|
|
|
|
| 242 |
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
| 243 |
# Sample using sampler
|
| 244 |
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
|
|
|
| 1 |
+
from math import sqrt
|
| 2 |
import torch.nn as nn
|
| 3 |
from einops import rearrange
|
| 4 |
from torch import Tensor
|
|
|
|
| 5 |
from functools import reduce
|
| 6 |
+
# from inspect import isfunction
|
| 7 |
+
# from math import ceil, floor, log2, pi
|
|
|
|
| 8 |
import torch
|
| 9 |
import torch.nn.functional as F
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 10 |
|
| 11 |
def default(val, d):
|
| 12 |
if val is not None: #exists(val):
|
| 13 |
return val
|
| 14 |
return d #d() if isfunction(d) else d
|
| 15 |
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 16 |
class LogNormalDistribution():
|
| 17 |
def __init__(self, mean: float, std: float):
|
| 18 |
self.mean = mean
|
|
|
|
| 195 |
|
| 196 |
# Compute sigmas using schedule
|
| 197 |
sigmas = self.sigma_schedule(num_steps, device)
|
| 198 |
+
|
| 199 |
+
# L242 KWARGS dict_keys(['embedding', 'features'])
|
| 200 |
fn = lambda *a, **ka: self.denoise_fn(*a, **{**ka, **kwargs}) # noqa
|
| 201 |
# Sample using sampler
|
| 202 |
x = self.sampler(noise, fn=fn, sigmas=sigmas, num_steps=num_steps)
|
api.py
CHANGED
|
@@ -171,8 +171,8 @@ def tts_multi_sentence(precomputed_style_vector=None,
|
|
| 171 |
precomputed_style_vector,
|
| 172 |
alpha=0.3,
|
| 173 |
beta=0.7,
|
| 174 |
-
diffusion_steps=diffusion_steps
|
| 175 |
-
|
| 176 |
x = np.concatenate(x)
|
| 177 |
|
| 178 |
# Fallback - MMS TTS - Non-English
|
|
@@ -530,7 +530,7 @@ def serve_wav():
|
|
| 530 |
|
| 531 |
# audios = [msinference.inference(text,
|
| 532 |
# msinference.compute_style(f'voices/{voice}.wav'),
|
| 533 |
-
# alpha=0.3, beta=0.7, diffusion_steps=7
|
| 534 |
# # for t in [text]:
|
| 535 |
# output_buffer = io.BytesIO()
|
| 536 |
# write(output_buffer, 24000, np.concatenate(audios))
|
|
|
|
| 171 |
precomputed_style_vector,
|
| 172 |
alpha=0.3,
|
| 173 |
beta=0.7,
|
| 174 |
+
diffusion_steps=diffusion_steps)
|
| 175 |
+
)
|
| 176 |
x = np.concatenate(x)
|
| 177 |
|
| 178 |
# Fallback - MMS TTS - Non-English
|
|
|
|
| 530 |
|
| 531 |
# audios = [msinference.inference(text,
|
| 532 |
# msinference.compute_style(f'voices/{voice}.wav'),
|
| 533 |
+
# alpha=0.3, beta=0.7, diffusion_steps=7)]
|
| 534 |
# # for t in [text]:
|
| 535 |
# output_buffer = io.BytesIO()
|
| 536 |
# write(output_buffer, 24000, np.concatenate(audios))
|
models.py
CHANGED
|
@@ -69,6 +69,7 @@ class AudioDiffusionConditional(nn.Module):
|
|
| 69 |
|
| 70 |
def forward(self, *args, **kwargs):
|
| 71 |
default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
|
|
|
|
| 72 |
return self.diffusion(*args, **{**default_kwargs, **kwargs})
|
| 73 |
|
| 74 |
# def sample(self, *args, **kwargs):
|
|
|
|
| 69 |
|
| 70 |
def forward(self, *args, **kwargs):
|
| 71 |
default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
|
| 72 |
+
# here embedding_scale = 1.0 is passed to DiffusionSampler() - del no-op if scale = 1.0
|
| 73 |
return self.diffusion(*args, **{**default_kwargs, **kwargs})
|
| 74 |
|
| 75 |
# def sample(self, *args, **kwargs):
|
msinference.py
CHANGED
|
@@ -174,7 +174,6 @@ def inference(text,
|
|
| 174 |
alpha = 0.3,
|
| 175 |
beta = 0.7,
|
| 176 |
diffusion_steps=7, # 7 if voice is native English else 5 for non-native
|
| 177 |
-
embedding_scale=1,
|
| 178 |
use_gruut=False):
|
| 179 |
text = text.strip()
|
| 180 |
ps = global_phonemizer.phonemize([text])
|
|
@@ -213,7 +212,6 @@ def inference(text,
|
|
| 213 |
|
| 214 |
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
| 215 |
embedding=bert_dur,
|
| 216 |
-
embedding_scale=embedding_scale,
|
| 217 |
features=ref_s, # reference from the same speaker as the embedding
|
| 218 |
num_steps=diffusion_steps).squeeze(1)
|
| 219 |
|
|
|
|
| 174 |
alpha = 0.3,
|
| 175 |
beta = 0.7,
|
| 176 |
diffusion_steps=7, # 7 if voice is native English else 5 for non-native
|
|
|
|
| 177 |
use_gruut=False):
|
| 178 |
text = text.strip()
|
| 179 |
ps = global_phonemizer.phonemize([text])
|
|
|
|
| 212 |
|
| 213 |
s_pred = sampler(noise = torch.randn((1, 256)).unsqueeze(1).to(device),
|
| 214 |
embedding=bert_dur,
|
|
|
|
| 215 |
features=ref_s, # reference from the same speaker as the embedding
|
| 216 |
num_steps=diffusion_steps).squeeze(1)
|
| 217 |
|