factor diffusion
Files changed:

- Modules/diffusion/diffusion.py  +0 -85
- Modules/diffusion/modules.py    +96 -77
- Modules/diffusion/sampler.py    +70 -46
- Modules/diffusion/utils.py      +0 -82
- models.py                       +77 -8
Modules/diffusion/diffusion.py
DELETED

@@ -1,85 +0,0 @@
-from math import pi
-from random import randint
-from typing import Any, Optional, Sequence, Tuple, Union
-
-import torch
-from einops import rearrange
-from torch import Tensor, nn
-from tqdm import tqdm
-
-from .utils import *
-from .sampler import *
-
-"""
-Diffusion Classes (generic for 1d data)
-"""
-
-
-class Model1d(nn.Module):
-    def __init__(self, unet_type: str = "base", **kwargs):
-        super().__init__()
-        diffusion_kwargs, kwargs = groupby("diffusion_", kwargs)
-        self.unet = None
-        self.diffusion = None
-
-    def forward(self, x: Tensor, **kwargs) -> Tensor:
-        return self.diffusion(x, **kwargs)
-
-    def sample(self, *args, **kwargs) -> Tensor:
-        return self.diffusion.sample(*args, **kwargs)
-
-
-"""
-Audio Diffusion Classes (specific for 1d audio data)
-"""
-
-
-def get_default_model_kwargs():
-    return dict(
-        channels=128,
-        patch_size=16,
-        multipliers=[1, 2, 4, 4, 4, 4, 4],
-        factors=[4, 4, 4, 2, 2, 2],
-        num_blocks=[2, 2, 2, 2, 2, 2],
-        attentions=[0, 0, 0, 1, 1, 1, 1],
-        attention_heads=8,
-        attention_features=64,
-        attention_multiplier=2,
-        attention_use_rel_pos=False,
-        diffusion_type="v",
-        diffusion_sigma_distribution=UniformDistribution(),
-    )
-
-
-def get_default_sampling_kwargs():
-    return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
-
-class AudioDiffusionConditional(Model1d):
-    def __init__(
-        self,
-        embedding_features: int,
-        embedding_max_length: int,
-        embedding_mask_proba: float = 0.1,
-        **kwargs,
-    ):
-        self.embedding_mask_proba = embedding_mask_proba
-        default_kwargs = dict(
-            **get_default_model_kwargs(),
-            unet_type="cfg",
-            context_embedding_features=embedding_features,
-            context_embedding_max_length=embedding_max_length,
-        )
-        super().__init__(**{**default_kwargs, **kwargs})
-
-    def forward(self, *args, **kwargs):
-        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
-        return super().forward(*args, **{**default_kwargs, **kwargs})
-
-    def sample(self, *args, **kwargs):
-        default_kwargs = dict(
-            **get_default_sampling_kwargs(),
-            embedding_scale=5.0,
-        )
-        return super().sample(*args, **{**default_kwargs, **kwargs})
-
-
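The deleted Model1d.__init__ above splits "diffusion_"-prefixed kwargs off with the groupby helper from Modules/diffusion/utils.py (also deleted in this commit, reproduced further down). A minimal sketch of what that splitting does, using made-up kwarg names purely for illustration:

    from typing import Dict, Tuple

    def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
        # Partition d into (keys with prefix, keys without prefix).
        return_dicts: Tuple[Dict, Dict] = ({}, {})
        for key in d.keys():
            no_prefix = int(not key.startswith(prefix))
            return_dicts[no_prefix][key] = d[key]
        return return_dicts

    def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
        kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
        if keep_prefix:
            return kwargs_with_prefix, kwargs
        kwargs_no_prefix = {k[len(prefix):]: v for k, v in kwargs_with_prefix.items()}
        return kwargs_no_prefix, kwargs

    # Hypothetical kwargs, as Model1d.__init__ would receive them:
    diffusion_kwargs, unet_kwargs = groupby("diffusion_", {"diffusion_type": "v", "channels": 128})
    assert diffusion_kwargs == {"type": "v"}   # prefix stripped
    assert unet_kwargs == {"channels": 128}    # left untouched for the UNet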
Modules/diffusion/modules.py
CHANGED

@@ -1,8 +1,5 @@
 from math import floor, log, pi
-
-
-from .utils import *
-
+import torch.nn.functional as F
 import torch
 import torch.nn as nn
 from einops import rearrange, reduce, repeat
@@ -11,9 +8,10 @@ from einops_exts import rearrange_many
 from torch import Tensor, einsum


-
-
-
+def default(val, d):
+    if val is not None: #exists(val):
+        return val
+    return d # d() if isfunction(d) else d

 class AdaLayerNorm(nn.Module):
     def __init__(self, style_dim, channels, eps=1e-5):
@@ -38,6 +36,9 @@ class AdaLayerNorm(nn.Module):
         return x.transpose(1, -1).transpose(-1, -2)

 class StyleTransformer1d(nn.Module):
+
+    # artificial_stylets / models.py
+
     def __init__(
         self,
         num_layers: int,
@@ -48,14 +49,14 @@ class StyleTransformer1d(nn.Module):
         use_context_time: bool = True,
         use_rel_pos: bool = False,
         context_features_multiplier: int = 1,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features
-        context_embedding_features
-        embedding_max_length
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
+        context_features=None,
+        context_embedding_features=None,
+        embedding_max_length=512,
     ):
         super().__init__()
-
+
         self.blocks = nn.ModuleList(
             [
                 StyleTransformerBlock(
@@ -65,8 +66,8 @@ class StyleTransformer1d(nn.Module):
                     multiplier=multiplier,
                     style_dim=context_features,
                     use_rel_pos=use_rel_pos,
-                    rel_pos_num_buckets=rel_pos_num_buckets,
-                    rel_pos_max_distance=rel_pos_max_distance,
+                    # rel_pos_num_buckets=rel_pos_num_buckets,
+                    # rel_pos_max_distance=rel_pos_max_distance,
                 )
                 for i in range(num_layers)
             ]
@@ -81,11 +82,14 @@ class StyleTransformer1d(nn.Module):
             ),
         )

-        use_context_features =
+        use_context_features = context_features is not None
        self.use_context_features = use_context_features
         self.use_context_time = use_context_time

         if use_context_time or use_context_features:
+            # print(f'{use_context_time=} {use_context_features=}ooooooooooooooooooooooooooooooooooo')
+            # raise ValueError
+            # True True both context
             context_mapping_features = channels + context_embedding_features

             self.to_mapping = nn.Sequential(
@@ -96,7 +100,7 @@ class StyleTransformer1d(nn.Module):
             )

         if use_context_time:
-
+
             self.to_time = nn.Sequential(
                 TimePositionalEmbedding(
                     dim=channels, out_features=context_mapping_features
@@ -105,7 +109,7 @@ class StyleTransformer1d(nn.Module):
             )

         if use_context_features:
-
+
             self.to_features = nn.Sequential(
                 nn.Linear(
                     in_features=context_features, out_features=context_mapping_features
@@ -119,23 +123,23 @@ class StyleTransformer1d(nn.Module):


     def get_mapping(
-        self,
-
+        self,
+        time=None,
+        features=None):
         """Combines context time features and features into mapping"""
         items, mapping = [], None
         # Compute time features
         if self.use_context_time:
-
-            assert exists(time), assert_message
+
             items += [self.to_time(time)]
         # Compute features
         if self.use_context_features:
-
-            assert exists(features), assert_message
+
             items += [self.to_features(features)]

         # Compute joint mapping
         if self.use_context_time or self.use_context_features:
+            # raise ValueError
             mapping = reduce(torch.stack(items), "n b m -> b m", "sum")
             mapping = self.to_mapping(mapping)

@@ -160,8 +164,8 @@ class StyleTransformer1d(nn.Module):
     def forward(self, x: Tensor,
                 time: Tensor,
                 embedding_mask_proba: float = 0.0,
-                embedding
-                features
+                embedding= None,
+                features = None,
                 embedding_scale: float = 1.0) -> Tensor:

         b, device = embedding.shape[0], embedding.device
@@ -174,13 +178,18 @@ class StyleTransformer1d(nn.Module):
         embedding = torch.where(batch_mask, fixed_embedding, embedding)

         if embedding_scale != 1.0:
-
+
+
             out = self.run(x, time, embedding=embedding, features=features)
             out_masked = self.run(x, time, embedding=fixed_embedding, features=features)
-
+
+            raise ValueError
             return out_masked + (out - out_masked) * embedding_scale
+
         else:
+            # raise ValueError
             return self.run(x, time, embedding=embedding, features=features)
+

         return x

@@ -194,42 +203,45 @@ class StyleTransformerBlock(nn.Module):
         style_dim: int,
         multiplier: int,
         use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
-        context_features
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
+        context_features = None,
     ):
         super().__init__()

-        self.use_cross_attention =
-
+        self.use_cross_attention = (context_features is not None) and (context_features > 0)
+        # print(f'{rel_pos_num_buckets=} {rel_pos_max_distance=}') # None None
+        # raise ValueError
         self.attention = StyleAttention(
             features=features,
             style_dim=style_dim,
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )

         if self.use_cross_attention:
-            self.cross_attention = StyleAttention(
-                features=features,
-                style_dim=style_dim,
-                num_heads=num_heads,
-                head_features=head_features,
-                context_features=context_features,
-                use_rel_pos=use_rel_pos,
-                rel_pos_num_buckets=rel_pos_num_buckets,
-                rel_pos_max_distance=rel_pos_max_distance,
-            )
+            raise ValueError
+            # self.cross_attention = StyleAttention(
+            #     features=features,
+            #     style_dim=style_dim,
+            #     num_heads=num_heads,
+            #     head_features=head_features,
+            #     context_features=context_features,
+            #     use_rel_pos=use_rel_pos,
+            #     rel_pos_num_buckets=rel_pos_num_buckets,
+            #     rel_pos_max_distance=rel_pos_max_distance,
+            # )

         self.feed_forward = FeedForward(features=features, multiplier=multiplier)

-    def forward(self, x: Tensor, s: Tensor, *, context
+    def forward(self, x: Tensor, s: Tensor, *, context = None) -> Tensor:
         x = self.attention(x, s) + x
         if self.use_cross_attention:
-            x = self.cross_attention(x, s, context=context) + x
+            raise ValueError
+            # x = self.cross_attention(x, s, context=context) + x
         x = self.feed_forward(x) + x
         return x

@@ -241,10 +253,10 @@ class StyleAttention(nn.Module):
         style_dim: int,
         head_features: int,
         num_heads: int,
-        context_features
+        context_features = None,
         use_rel_pos: bool,
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.context_features = context_features
@@ -264,15 +276,16 @@ class StyleAttention(nn.Module):
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )

-    def forward(self, x: Tensor, s: Tensor, *, context
-
-
+    def forward(self, x: Tensor, s: Tensor, *, context = None):
+
+        # raise ValueError
         # Use context if provided
         context = default(context, x)
+        # print(context.shape,'ppppppppppppppppppppppppppppppppppppppppppp') # bs, time, 1024
         # Normalize then compute q from input and k,v from context
         x, context = self.norm(x, s), self.norm_context(context, s)

@@ -280,7 +293,9 @@ class StyleAttention(nn.Module):
         # Compute and return attention
         return self.attention(q, k, v)

-def FeedForward(features: int, multiplier: int) -> nn.Module:
+
+def FeedForward(features,
+                multiplier):
     mid_features = features * multiplier
     return nn.Sequential(
         nn.Linear(in_features=features, out_features=mid_features),
@@ -292,14 +307,14 @@ def FeedForward(features: int, multiplier: int) -> nn.Module:
 class AttentionBase(nn.Module):
     def __init__(
         self,
-        features
+        features,
         *,
-        head_features
-        num_heads
-        use_rel_pos
-        out_features
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        head_features,
+        num_heads,
+        use_rel_pos,
+        out_features = None,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.scale = head_features ** -0.5
@@ -320,7 +335,11 @@ class AttentionBase(nn.Module):
         q, k, v = rearrange_many((q, k, v), "b n (h d) -> b h n d", h=self.num_heads)
         # Compute similarity matrix
         sim = einsum("... n d, ... m d -> ... n m", q, k)
-        sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+
+        # _____THERE_IS_NO_rel_po
+        # sim = (sim + self.rel_pos(*sim.shape[-2:])) if self.use_rel_pos else sim
+        # print(self.rel_pos)
+
         sim = sim * self.scale
         # Get attention matrix with softmax
         attn = sim.softmax(dim=-1)
@@ -333,15 +352,15 @@ class AttentionBase(nn.Module):
 class Attention(nn.Module):
     def __init__(
         self,
-        features
+        features,
         *,
-        head_features
-        num_heads
-        out_features
-        context_features
-        use_rel_pos
-        rel_pos_num_buckets: Optional[int] = None,
-        rel_pos_max_distance: Optional[int] = None,
+        head_features,
+        num_heads,
+        out_features=None,
+        context_features=None,
+        use_rel_pos,
+        # rel_pos_num_buckets: Optional[int] = None,
+        # rel_pos_max_distance: Optional[int] = None,
     ):
         super().__init__()
         self.context_features = context_features
@@ -363,13 +382,13 @@ class Attention(nn.Module):
             num_heads=num_heads,
             head_features=head_features,
             use_rel_pos=use_rel_pos,
-            rel_pos_num_buckets=rel_pos_num_buckets,
-            rel_pos_max_distance=rel_pos_max_distance,
+            # rel_pos_num_buckets=rel_pos_num_buckets,
+            # rel_pos_max_distance=rel_pos_max_distance,
         )

-    def forward(self, x: Tensor, *, context
-        assert_message = "You must provide a context when using context_features"
-        assert not self.context_features or exists(context), assert_message
+    def forward(self, x: Tensor, *, context = None) -> Tensor:
+        # assert_message = "You must provide a context when using context_features"
+        # assert not self.context_features or exists(context), assert_message
         # Use context if provided
         context = default(context, x)
         # Normalize then compute q from input and k,v from context
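The embedding_scale != 1.0 branch of StyleTransformer1d.forward is classifier-free guidance: the network is run twice, once with the real embedding and once with the learned fixed (unconditional) embedding, and the two outputs are blended. Note the new code also inserts a bare raise ValueError ahead of that return, so the scaled path is effectively disabled in this commit. A self-contained sketch of the blend itself (tensor shapes are illustrative only):

    import torch

    def cfg_combine(out: torch.Tensor, out_masked: torch.Tensor, scale: float) -> torch.Tensor:
        # Same arithmetic as the return in StyleTransformer1d.forward:
        # scale = 1 recovers the conditional output exactly; scale > 1
        # extrapolates from the unconditional output towards the conditional one.
        return out_masked + (out - out_masked) * scale

    out = torch.randn(2, 256, 64)         # prediction with the real embedding
    out_masked = torch.randn(2, 256, 64)  # prediction with the fixed embedding
    assert torch.allclose(cfg_combine(out, out_masked, 1.0), out)
    guided = cfg_combine(out, out_masked, 5.0)  # 5.0 was the old default embedding_scale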
Modules/diffusion/sampler.py
CHANGED

@@ -1,11 +1,59 @@
 from math import atan, cos, pi, sin, sqrt
-from typing import Any, Callable, List, Optional, Tuple, Type
-import torch
 import torch.nn as nn
-import torch.nn.functional as F
 from einops import rearrange
 from torch import Tensor
-from .utils import *
+
+from functools import reduce
+from inspect import isfunction
+from math import ceil, floor, log2, pi
+# from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
+import torch
+import torch.nn.functional as F
+from einops import rearrange
+from torch import Generator, Tensor
+
+
+
+
+
+
+# def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
+#     return isinstance(obj, list) or isinstance(obj, tuple)
+
+
+def default(val, d):
+    if val is not None: #exists(val):
+        return val
+    return d #d() if isfunction(d) else d
+
+
+# def to_list(val: Union[T, Sequence[T]]) -> List[T]:
+#     if isinstance(val, tuple):
+#         return list(val)
+#     if isinstance(val, list):
+#         return val
+#     return [val] # type: ignore
+
+
+# def prod(vals: Sequence[int]) -> int:
+#     return reduce(lambda x, y: x * y, vals)
+
+
+def closest_power_2(x: float) -> int:
+    exponent = log2(x)
+    distance_fn = lambda z: abs(x - 2 ** z) # noqa
+    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
+    return 2 ** int(exponent_closest)
+
+def rand_bool(shape, proba, device = None):
+    if proba == 1:
+        return torch.ones(shape, device=device, dtype=torch.bool)
+    elif proba == 0:
+        return torch.zeros(shape, device=device, dtype=torch.bool)
+    else:
+        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
+
+# ============================= END functions from diffusION.utils


 class LogNormalDistribution():
@@ -29,14 +77,13 @@
 def to_batch(
     batch_size: int,
     device: torch.device,
-    x
-    xs
-)
-    assert exists(x) ^ exists(xs), "Either x or xs must be provided"
+    x = None,
+    xs = None):
+    # assert exists(x) ^ exists(xs), "Either x or xs must be provided"
     # If x provided use the same for all batch items
-    if exists(x):
+    if x is not None: #exists(x):
         xs = torch.full(size=(batch_size,), fill_value=x).to(device)
-    assert exists(xs)
+    # assert exists(xs)
     return xs

 class KDiffusion(nn.Module):
@@ -58,7 +105,7 @@ class KDiffusion(nn.Module):
         self.sigma_distribution = sigma_distribution
         self.dynamic_threshold = dynamic_threshold

-    def get_scale_weights(self, sigmas
+    def get_scale_weights(self, sigmas):
         sigma_data = self.sigma_data
         c_noise = torch.log(sigmas) * 0.25
         sigmas = rearrange(sigmas, "b -> b 1 1")
@@ -69,9 +116,9 @@ class KDiffusion(nn.Module):

     def denoise_fn(
         self,
-        x_noisy
-        sigmas
-        sigma
+        x_noisy,
+        sigmas = None,
+        sigma = None,
         **kwargs,
     ):
         # raise ValueError
@@ -107,7 +154,7 @@ class KarrasSchedule(nn.Module):
         self.sigma_max = sigma_max
         self.rho = rho

-    def forward(self, num_steps: int, device
+    def forward(self, num_steps: int, device):
         rho_inv = 1.0 / self.rho
         steps = torch.arange(num_steps, device=device, dtype=torch.float32)
         sigmas = (
@@ -118,32 +165,7 @@ class KarrasSchedule(nn.Module):
         sigmas = F.pad(sigmas, pad=(0, 1), value=0.0)
         return sigmas

-
-""" Samplers """
-
-
-class Sampler(nn.Module):
-
-
-
-    def forward(
-        self, noise: Tensor, fn: Callable, sigmas: Tensor, num_steps: int
-    ) -> Tensor:
-        raise NotImplementedError()
-
-    def inpaint(
-        self,
-        source: Tensor,
-        mask: Tensor,
-        fn: Callable,
-        sigmas: Tensor,
-        num_steps: int,
-        num_resamples: int,
-    ) -> Tensor:
-        raise NotImplementedError("Inpainting not available with current sampler")
-
-
-class ADPM2Sampler(Sampler):
+class ADPM2Sampler(nn.Module):
     """https://www.desmos.com/calculator/jbxjlqd9mb"""

     diffusion_types = [KDiffusion,] # VKDiffusion]
@@ -152,15 +174,17 @@ class ADPM2Sampler(nn.Module):
         super().__init__()
         self.rho = rho

-    def get_sigmas(self,
+    def get_sigmas(self,
+                   sigma,
+                   sigma_next):
         r = self.rho
         sigma_up = sqrt(sigma_next ** 2 * (sigma ** 2 - sigma_next ** 2) / sigma ** 2)
         sigma_down = sqrt(sigma_next ** 2 - sigma_up ** 2)
         sigma_mid = ((sigma ** (1 / r) + sigma_down ** (1 / r)) / 2) ** r
         return sigma_up, sigma_down, sigma_mid

-    def step(self, x
-
+    def step(self, x, fn, sigma, sigma_next):
+
         sigma_up, sigma_down, sigma_mid = self.get_sigmas(sigma, sigma_next)
         # Derivative at sigma (∂x/∂sigma)
         d = (x - fn(x, sigma=sigma)) / sigma
@@ -175,7 +199,7 @@ class ADPM2Sampler(nn.Module):
         return x_next

     def forward(
-        self, noise
+        self, noise, fn, sigmas, num_steps):
         # raise ValueError
         x = sigmas[0] * noise
         # Denoise to sample
@@ -211,7 +235,7 @@ class DiffusionSampler(nn.Module):
         # raise ValueError
         device = noise.device
         num_steps = default(num_steps, self.num_steps) # type: ignore
-
+
         # Compute sigmas using schedule
         sigmas = self.sigma_schedule(num_steps, device)
         # Append additional kwargs to denoise function (used e.g. for conditional unet)
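KarrasSchedule.forward implements the noise schedule of Karras et al. (2022). The body of its `sigmas = (` expression is cut off in this view, so the sketch below fills it in from the published formula; treat it as a reconstruction under that assumption, not a verbatim copy of the file:

    import torch
    import torch.nn.functional as F

    def karras_sigmas(num_steps: int, sigma_min: float, sigma_max: float,
                      rho: float = 7.0, device: str = "cpu") -> torch.Tensor:
        rho_inv = 1.0 / rho
        steps = torch.arange(num_steps, device=device, dtype=torch.float32)
        # Interpolate between sigma_max and sigma_min in rho-warped space.
        sigmas = (
            sigma_max ** rho_inv
            + (steps / (num_steps - 1)) * (sigma_min ** rho_inv - sigma_max ** rho_inv)
        ) ** rho
        # Append sigma = 0 for the final, fully denoised step, as the file does.
        return F.pad(sigmas, pad=(0, 1), value=0.0)

    print(karras_sigmas(5, sigma_min=0.002, sigma_max=80.0))
    # values decrease monotonically from sigma_max to sigma_min, then 0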
Modules/diffusion/utils.py
DELETED

@@ -1,82 +0,0 @@
-from functools import reduce
-from inspect import isfunction
-from math import ceil, floor, log2, pi
-from typing import Callable, Dict, List, Optional, Sequence, Tuple, TypeVar, Union
-
-import torch
-import torch.nn.functional as F
-from einops import rearrange
-from torch import Generator, Tensor
-from typing_extensions import TypeGuard
-
-T = TypeVar("T")
-
-
-def exists(val: Optional[T]) -> TypeGuard[T]:
-    return val is not None
-
-
-def iff(condition: bool, value: T) -> Optional[T]:
-    return value if condition else None
-
-
-def is_sequence(obj: T) -> TypeGuard[Union[list, tuple]]:
-    return isinstance(obj, list) or isinstance(obj, tuple)
-
-
-def default(val: Optional[T], d: Union[Callable[..., T], T]) -> T:
-    if exists(val):
-        return val
-    return d() if isfunction(d) else d
-
-
-def to_list(val: Union[T, Sequence[T]]) -> List[T]:
-    if isinstance(val, tuple):
-        return list(val)
-    if isinstance(val, list):
-        return val
-    return [val] # type: ignore
-
-
-def prod(vals: Sequence[int]) -> int:
-    return reduce(lambda x, y: x * y, vals)
-
-
-def closest_power_2(x: float) -> int:
-    exponent = log2(x)
-    distance_fn = lambda z: abs(x - 2 ** z) # noqa
-    exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
-    return 2 ** int(exponent_closest)
-
-def rand_bool(shape, proba, device = None):
-    if proba == 1:
-        return torch.ones(shape, device=device, dtype=torch.bool)
-    elif proba == 0:
-        return torch.zeros(shape, device=device, dtype=torch.bool)
-    else:
-        return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)
-
-
-"""
-Kwargs Utils
-"""
-
-
-def group_dict_by_prefix(prefix: str, d: Dict) -> Tuple[Dict, Dict]:
-    return_dicts: Tuple[Dict, Dict] = ({}, {})
-    for key in d.keys():
-        no_prefix = int(not key.startswith(prefix))
-        return_dicts[no_prefix][key] = d[key]
-    return return_dicts
-
-
-def groupby(prefix: str, d: Dict, keep_prefix: bool = False) -> Tuple[Dict, Dict]:
-    kwargs_with_prefix, kwargs = group_dict_by_prefix(prefix, d)
-    if keep_prefix:
-        return kwargs_with_prefix, kwargs
-    kwargs_no_prefix = {k[len(prefix) :]: v for k, v in kwargs_with_prefix.items()}
-    return kwargs_no_prefix, kwargs
-
-
-def prefix_dict(prefix: str, d: Dict) -> Dict:
-    return {prefix + str(k): v for k, v in d.items()}
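Most of these helpers were inlined into Modules/diffusion/sampler.py above (up to the "END functions from diffusION.utils" marker) or commented out. One behavioral difference worth noting: this original `default` evaluates callable defaults (`d() if isfunction(d) else d`), while the inlined copies return `d` unevaluated. Quick usage checks for the two less obvious survivors, copied from the file:

    import torch
    from math import ceil, floor, log2

    def closest_power_2(x: float) -> int:
        # Round x to the nearest power of two; exact ties resolve to the
        # lower power, because min() keeps the first of equal keys.
        exponent = log2(x)
        distance_fn = lambda z: abs(x - 2 ** z)
        exponent_closest = min((floor(exponent), ceil(exponent)), key=distance_fn)
        return 2 ** int(exponent_closest)

    assert closest_power_2(7.0) == 8   # |7 - 8| < |7 - 4|
    assert closest_power_2(5.0) == 4   # |5 - 4| < |5 - 8|

    def rand_bool(shape, proba, device=None):
        # Per-element Bernoulli mask; this is what drives the random embedding
        # dropout (the fixed_embedding swap) in StyleTransformer1d.forward.
        if proba == 1:
            return torch.ones(shape, device=device, dtype=torch.bool)
        elif proba == 0:
            return torch.zeros(shape, device=device, dtype=torch.bool)
        else:
            return torch.bernoulli(torch.full(shape, proba, device=device)).to(torch.bool)

    mask = rand_bool((2, 1, 1), proba=0.1)
    print(mask.dtype, mask.shape)  # torch.bool torch.Size([2, 1, 1])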
models.py
CHANGED

@@ -2,27 +2,96 @@

 import os
 import os.path as osp
-
 import copy
 import math
-
 import numpy as np
 import torch
 import torch.nn as nn
 import torch.nn.functional as F
 from torch.nn.utils import weight_norm, remove_weight_norm, spectral_norm
-
 from Utils.ASR.models import ASRCNN
 from Utils.JDC.model import JDCNet
-
 from Modules.diffusion.sampler import KDiffusion, LogNormalDistribution
 from Modules.diffusion.modules import StyleTransformer1d
-from Modules.diffusion.diffusion import AudioDiffusionConditional
+# from Modules.diffusion.diffusion import AudioDiffusionConditional
+from munch import Munch
+import yaml
+from math import pi
+from random import randint
+# from typing import Any, Optional, Sequence, Tuple, Union
+import torch
+from einops import rearrange
+from torch import Tensor, nn
+from tqdm import tqdm
+# from Modules.diffusion.utils import *
+# from Modules.diffusion.sampler import *
+
+
+
+
+def get_default_model_kwargs():
+    return dict(
+        channels=128,
+        patch_size=16,
+        multipliers=[1, 2, 4, 4, 4, 4, 4],
+        factors=[4, 4, 4, 2, 2, 2],
+        num_blocks=[2, 2, 2, 2, 2, 2],
+        attentions=[0, 0, 0, 1, 1, 1, 1],
+        attention_heads=8,
+        attention_features=64,
+        attention_multiplier=2,
+        attention_use_rel_pos=False,
+        diffusion_type="v",
+        diffusion_sigma_distribution=UniformDistribution(),
+    )
+
+
+def get_default_sampling_kwargs():
+    return dict(sigma_schedule=LinearSchedule(), sampler=VSampler(), clamp=True)
+
+class AudioDiffusionConditional(nn.Module):
+    def __init__(
+        self,
+        embedding_features: int,
+        embedding_max_length: int,
+        embedding_mask_proba: float = 0.1,
+        **kwargs,
+    ):
+        self.unet = None
+        self.embedding_mask_proba = embedding_mask_proba
+        # default_kwargs = dict(
+        #     **get_default_model_kwargs(),
+        #     unet_type="cfg",
+        #     context_embedding_features=embedding_features,
+        #     context_embedding_max_length=embedding_max_length,
+        # )
+        super().__init__()
+
+    def forward(self, *args, **kwargs):
+        default_kwargs = dict(embedding_mask_proba=self.embedding_mask_proba)
+        return self.diffusion(*args, **{**default_kwargs, **kwargs})
+
+    # def sample(self, *args, **kwargs):
+    #     default_kwargs = dict(
+    #         **get_default_sampling_kwargs(),
+    #         embedding_scale=5.0,
+    #     )
+    #     return super().sample(*args, **{**default_kwargs, **kwargs})
+
+
+
+
+
+
+
+
+
+
+
+



-from munch import Munch
-import yaml

 class LearnedDownSample(nn.Module):
     def __init__(self, layer_type, dim_in):
@@ -561,7 +630,7 @@ def build_model(args, text_aligner, pitch_extractor, bert):
             channels=args.style_dim*2,
             context_features=args.style_dim*2,
         )
-
+        # this initialises self.diffusion for AudioDiffusionConditional
        diffusion.diffusion = KDiffusion(
            net=diffusion.unet,
            sigma_distribution=LogNormalDistribution(mean = args.diffusion.dist.mean, std = args.diffusion.dist.std),