Clone tab: search fonts/pt_models with filename aliases; vendor models/
Browse files- .gitignore +2 -0
- app.py +57 -33
- models/__init__.py +0 -0
- models/reference_encoder.py +86 -0
- models/text_encoder.py +358 -0
- models/utils.py +197 -0
- models/vf_estimator.py +507 -0
.gitignore
CHANGED
|
@@ -13,6 +13,8 @@ voices
|
|
| 13 |
renikud.onnx
|
| 14 |
model.onnx
|
| 15 |
pt_weights
|
|
|
|
|
|
|
| 16 |
|
| 17 |
# Virtual environments
|
| 18 |
venv/
|
|
|
|
| 13 |
renikud.onnx
|
| 14 |
model.onnx
|
| 15 |
pt_weights
|
| 16 |
+
fonts/pt_models
|
| 17 |
+
pt_models
|
| 18 |
|
| 19 |
# Virtual environments
|
| 20 |
venv/
|
app.py
CHANGED
|
@@ -122,7 +122,7 @@ class TextProcessor:
|
|
| 122 |
from phonemizer.separator import Separator
|
| 123 |
EspeakWrapper.set_library(espeakng_loader.get_library_path())
|
| 124 |
if hasattr(EspeakWrapper, "set_data_path"):
|
| 125 |
-
|
| 126 |
self._espeak_separator = Separator(phone="", word=" ", syllable="")
|
| 127 |
self._espeak_ready = True
|
| 128 |
except Exception as e:
|
|
@@ -157,7 +157,7 @@ class TextProcessor:
|
|
| 157 |
return re.sub(r"\s+", " ", r.stdout.replace("\n", " ")).strip()
|
| 158 |
except Exception as e:
|
| 159 |
print(f"[WARN] espeak-ng subprocess failed for {lang}: {e}")
|
| 160 |
-
|
| 161 |
|
| 162 |
def _phonemize_segment(self, content: str, lang: str) -> str:
|
| 163 |
content = content.strip()
|
|
@@ -211,7 +211,7 @@ class UnicodeProcessor:
|
|
| 211 |
if isinstance(raw, dict) and "char_to_id" in raw:
|
| 212 |
self.pad_id = int(raw.get("pad_id", 0))
|
| 213 |
self._char_to_id = {k: int(v) for k, v in raw["char_to_id"].items()}
|
| 214 |
-
|
| 215 |
self._char_to_id = {chr(int(k)): int(v) for k, v in raw.items()}
|
| 216 |
print(f"[INFO] Loaded vocab from {indexer_path} ({len(self._char_to_id)} entries)")
|
| 217 |
else:
|
|
@@ -327,7 +327,7 @@ def chunk_text(text: str, max_len: int = 300) -> List[str]:
|
|
| 327 |
for sentence in re.split(pattern, paragraph):
|
| 328 |
if len(current) + len(sentence) + 1 <= max_len:
|
| 329 |
current += (" " if current else "") + sentence
|
| 330 |
-
|
| 331 |
if current:
|
| 332 |
chunks.append(current.strip())
|
| 333 |
current = sentence
|
|
@@ -550,7 +550,30 @@ def synthesize_text(text: str, voice: str, lang: str, steps: int, speed: float,
|
|
| 550 |
# Voice-clone tab (runs export_new_voice.py)
|
| 551 |
# ============================================================
|
| 552 |
EXPORT_SCRIPT = os.path.join(os.path.dirname(__file__), "export_new_voice.py")
|
| 553 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 554 |
|
| 555 |
|
| 556 |
def _refresh_voices() -> None:
|
|
@@ -568,19 +591,20 @@ def clone_voice(ref_wav: Optional[str], voice_name: str):
|
|
| 568 |
safe = re.sub(r"[^\w\-]+", "_", voice_name.strip())
|
| 569 |
out_path = os.path.join(VOICES_DIR, f"{safe}.json")
|
| 570 |
|
| 571 |
-
needed = {
|
| 572 |
-
|
| 573 |
-
"ttl_ckpt": os.path.join(PT_WEIGHTS_DIR, "vf_estimator.safetensors"),
|
| 574 |
-
"dp_ckpt": os.path.join(PT_WEIGHTS_DIR, "duration_predictor.safetensors"),
|
| 575 |
-
"stats": os.path.join(PT_WEIGHTS_DIR, "stats_multilingual.pt"),
|
| 576 |
-
}
|
| 577 |
-
missing = [v for v in needed.values() if not os.path.exists(v)]
|
| 578 |
if missing:
|
|
|
|
|
|
|
|
|
|
|
|
|
| 579 |
return (
|
| 580 |
-
"Voice cloning needs PyTorch checkpoints.
|
| 581 |
-
"
|
| 582 |
-
"
|
| 583 |
-
|
|
|
|
|
|
|
| 584 |
), gr.update()
|
| 585 |
|
| 586 |
cmd = [
|
|
@@ -666,40 +690,40 @@ with gr.Blocks(title="BlueTTS — Multilingual TTS") as demo:
|
|
| 666 |
|
| 667 |
with gr.Tabs():
|
| 668 |
with gr.TabItem("Synthesize"):
|
| 669 |
-
|
| 670 |
-
|
| 671 |
label="Text", placeholder="Type or paste text here…",
|
| 672 |
lines=4, elem_classes="big-input",
|
| 673 |
-
|
| 674 |
-
|
| 675 |
-
|
| 676 |
-
|
| 677 |
-
|
| 678 |
choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"),
|
| 679 |
("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"),
|
| 680 |
("Italian 🇮🇹", "it")],
|
| 681 |
value="en", label="Language", elem_classes="ctrl-lang",
|
| 682 |
)
|
| 683 |
-
|
| 684 |
choices=list(VOICES.keys()),
|
| 685 |
value=next(iter(VOICES.keys()), None),
|
| 686 |
label="Voice", elem_classes="ctrl-voice",
|
| 687 |
-
|
| 688 |
-
|
| 689 |
steps_input = gr.Slider(2, 32, 8, step=1, label="Quality (steps)", elem_classes="ctrl-steps")
|
| 690 |
speed_input = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Speed", elem_classes="ctrl-speed")
|
| 691 |
cfg_input = gr.Slider(1.0, 7.0, 3.0, step=0.1, label="CFG Scale", elem_classes="ctrl-cfg")
|
| 692 |
-
|
| 693 |
-
|
| 694 |
-
|
| 695 |
|
| 696 |
gr.Examples(examples=EXAMPLES, inputs=[text_input, lang_input], label="Examples")
|
| 697 |
|
| 698 |
-
|
| 699 |
-
|
| 700 |
inputs=[text_input, voice_input, lang_input, steps_input, speed_input, cfg_input],
|
| 701 |
-
|
| 702 |
-
|
| 703 |
|
| 704 |
with gr.TabItem("Clone Voice"):
|
| 705 |
with gr.Column(elem_classes="card"):
|
|
|
|
| 122 |
from phonemizer.separator import Separator
|
| 123 |
EspeakWrapper.set_library(espeakng_loader.get_library_path())
|
| 124 |
if hasattr(EspeakWrapper, "set_data_path"):
|
| 125 |
+
EspeakWrapper.set_data_path(espeakng_loader.get_data_path())
|
| 126 |
self._espeak_separator = Separator(phone="", word=" ", syllable="")
|
| 127 |
self._espeak_ready = True
|
| 128 |
except Exception as e:
|
|
|
|
| 157 |
return re.sub(r"\s+", " ", r.stdout.replace("\n", " ")).strip()
|
| 158 |
except Exception as e:
|
| 159 |
print(f"[WARN] espeak-ng subprocess failed for {lang}: {e}")
|
| 160 |
+
return text
|
| 161 |
|
| 162 |
def _phonemize_segment(self, content: str, lang: str) -> str:
|
| 163 |
content = content.strip()
|
|
|
|
| 211 |
if isinstance(raw, dict) and "char_to_id" in raw:
|
| 212 |
self.pad_id = int(raw.get("pad_id", 0))
|
| 213 |
self._char_to_id = {k: int(v) for k, v in raw["char_to_id"].items()}
|
| 214 |
+
else:
|
| 215 |
self._char_to_id = {chr(int(k)): int(v) for k, v in raw.items()}
|
| 216 |
print(f"[INFO] Loaded vocab from {indexer_path} ({len(self._char_to_id)} entries)")
|
| 217 |
else:
|
|
|
|
| 327 |
for sentence in re.split(pattern, paragraph):
|
| 328 |
if len(current) + len(sentence) + 1 <= max_len:
|
| 329 |
current += (" " if current else "") + sentence
|
| 330 |
+
else:
|
| 331 |
if current:
|
| 332 |
chunks.append(current.strip())
|
| 333 |
current = sentence
|
|
|
|
| 550 |
# Voice-clone tab (runs export_new_voice.py)
|
| 551 |
# ============================================================
|
| 552 |
EXPORT_SCRIPT = os.path.join(os.path.dirname(__file__), "export_new_voice.py")
|
| 553 |
+
|
| 554 |
+
# Accept checkpoints from a handful of common locations (with the filename
|
| 555 |
+
# variants we've seen in the wild) so the clone tab works out of the box.
|
| 556 |
+
PT_WEIGHTS_SEARCH = [
|
| 557 |
+
"pt_weights",
|
| 558 |
+
os.path.join("fonts", "pt_models"),
|
| 559 |
+
"pt_models",
|
| 560 |
+
]
|
| 561 |
+
PT_WEIGHT_ALIASES: dict[str, list[str]] = {
|
| 562 |
+
"ae_ckpt": ["blue_codec.safetensors", "blue_codec.pt"],
|
| 563 |
+
"ttl_ckpt": ["vf_estimator.safetensors", "vf_estimator.pt", "vf_estimetor.pt"],
|
| 564 |
+
"dp_ckpt": ["duration_predictor.safetensors", "duration_predictor.pt",
|
| 565 |
+
"duration_predictor_final.pt"],
|
| 566 |
+
"stats": ["stats_multilingual.pt", "stats.pt"],
|
| 567 |
+
}
|
| 568 |
+
|
| 569 |
+
|
| 570 |
+
def _find_pt_weight(aliases: list[str]) -> Optional[str]:
|
| 571 |
+
for d in PT_WEIGHTS_SEARCH:
|
| 572 |
+
for name in aliases:
|
| 573 |
+
p = os.path.join(d, name)
|
| 574 |
+
if os.path.exists(p):
|
| 575 |
+
return p
|
| 576 |
+
return None
|
| 577 |
|
| 578 |
|
| 579 |
def _refresh_voices() -> None:
|
|
|
|
| 591 |
safe = re.sub(r"[^\w\-]+", "_", voice_name.strip())
|
| 592 |
out_path = os.path.join(VOICES_DIR, f"{safe}.json")
|
| 593 |
|
| 594 |
+
needed: dict[str, Optional[str]] = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
|
| 595 |
+
missing = [k for k, v in needed.items() if v is None]
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 596 |
if missing:
|
| 597 |
+
searched = ", ".join(PT_WEIGHTS_SEARCH)
|
| 598 |
+
wanted = "\n".join(
|
| 599 |
+
f" {k}: any of {PT_WEIGHT_ALIASES[k]}" for k in missing
|
| 600 |
+
)
|
| 601 |
return (
|
| 602 |
+
"Voice cloning needs PyTorch checkpoints. I looked in: "
|
| 603 |
+
f"{searched}\nMissing:\n{wanted}\n\n"
|
| 604 |
+
"Fetch them with:\n"
|
| 605 |
+
" hf download notmax123/blue blue_codec.safetensors "
|
| 606 |
+
"duration_predictor.safetensors vf_estimator.safetensors "
|
| 607 |
+
"stats_multilingual.pt --local-dir pt_weights"
|
| 608 |
), gr.update()
|
| 609 |
|
| 610 |
cmd = [
|
|
|
|
| 690 |
|
| 691 |
with gr.Tabs():
|
| 692 |
with gr.TabItem("Synthesize"):
|
| 693 |
+
with gr.Column(elem_classes="card"):
|
| 694 |
+
text_input = gr.Textbox(
|
| 695 |
label="Text", placeholder="Type or paste text here…",
|
| 696 |
lines=4, elem_classes="big-input",
|
| 697 |
+
value="Great ideas become real when a small team keeps building every single day.",
|
| 698 |
+
)
|
| 699 |
+
with gr.Column(elem_classes="controls-row"):
|
| 700 |
+
with gr.Row(elem_classes="ctrl-row1"):
|
| 701 |
+
lang_input = gr.Dropdown(
|
| 702 |
choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"),
|
| 703 |
("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"),
|
| 704 |
("Italian 🇮🇹", "it")],
|
| 705 |
value="en", label="Language", elem_classes="ctrl-lang",
|
| 706 |
)
|
| 707 |
+
voice_input = gr.Dropdown(
|
| 708 |
choices=list(VOICES.keys()),
|
| 709 |
value=next(iter(VOICES.keys()), None),
|
| 710 |
label="Voice", elem_classes="ctrl-voice",
|
| 711 |
+
)
|
| 712 |
+
with gr.Row(elem_classes="ctrl-row2"):
|
| 713 |
steps_input = gr.Slider(2, 32, 8, step=1, label="Quality (steps)", elem_classes="ctrl-steps")
|
| 714 |
speed_input = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Speed", elem_classes="ctrl-speed")
|
| 715 |
cfg_input = gr.Slider(1.0, 7.0, 3.0, step=0.1, label="CFG Scale", elem_classes="ctrl-cfg")
|
| 716 |
+
btn = gr.Button("⚡ Generate Speech", elem_classes="gen-btn")
|
| 717 |
+
audio_out = gr.Audio(label="Output", type="numpy", autoplay=True)
|
| 718 |
+
stats_out = gr.HTML()
|
| 719 |
|
| 720 |
gr.Examples(examples=EXAMPLES, inputs=[text_input, lang_input], label="Examples")
|
| 721 |
|
| 722 |
+
btn.click(
|
| 723 |
+
synthesize_text,
|
| 724 |
inputs=[text_input, voice_input, lang_input, steps_input, speed_input, cfg_input],
|
| 725 |
+
outputs=[audio_out, stats_out],
|
| 726 |
+
)
|
| 727 |
|
| 728 |
with gr.TabItem("Clone Voice"):
|
| 729 |
with gr.Column(elem_classes="card"):
|
models/__init__.py
ADDED
|
File without changes
|
models/reference_encoder.py
ADDED
|
@@ -0,0 +1,86 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
|
| 4 |
+
from .text_encoder import ConvNeXtWrapper
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
class ReferenceEncoder(nn.Module):
|
| 8 |
+
def __init__(
|
| 9 |
+
self,
|
| 10 |
+
in_channels: int = 144,
|
| 11 |
+
d_model: int = 256,
|
| 12 |
+
hidden_dim: int = 1024,
|
| 13 |
+
num_blocks: int = 6,
|
| 14 |
+
num_tokens: int = 50,
|
| 15 |
+
num_heads: int = 2,
|
| 16 |
+
kernel_size: int = 5,
|
| 17 |
+
dilation_lst: list = None,
|
| 18 |
+
prototype_dim: int = 256,
|
| 19 |
+
n_units: int = 256,
|
| 20 |
+
style_value_dim: int = 256,
|
| 21 |
+
):
|
| 22 |
+
super().__init__()
|
| 23 |
+
self.d_model = d_model
|
| 24 |
+
self.num_tokens = num_tokens
|
| 25 |
+
|
| 26 |
+
if hidden_dim % d_model != 0:
|
| 27 |
+
raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by d_model ({d_model})")
|
| 28 |
+
mlp_ratio = hidden_dim // d_model
|
| 29 |
+
|
| 30 |
+
self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
|
| 31 |
+
self.convnext = ConvNeXtWrapper(
|
| 32 |
+
d_model,
|
| 33 |
+
n_layers=num_blocks,
|
| 34 |
+
expansion_factor=mlp_ratio,
|
| 35 |
+
kernel_size=kernel_size,
|
| 36 |
+
dilation_lst=dilation_lst,
|
| 37 |
+
)
|
| 38 |
+
|
| 39 |
+
self.ref_keys = nn.Parameter(torch.randn(num_tokens, prototype_dim) * 0.02)
|
| 40 |
+
self.q_proj = nn.Linear(prototype_dim, n_units) if prototype_dim != n_units else nn.Identity()
|
| 41 |
+
self.out_proj = nn.Linear(n_units, style_value_dim) if n_units != style_value_dim else nn.Identity()
|
| 42 |
+
|
| 43 |
+
self.attn1 = nn.MultiheadAttention(
|
| 44 |
+
embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
|
| 45 |
+
)
|
| 46 |
+
self.attn2 = nn.MultiheadAttention(
|
| 47 |
+
embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
|
| 48 |
+
)
|
| 49 |
+
|
| 50 |
+
def forward(self, z_ref: torch.Tensor, mask: torch.Tensor = None):
|
| 51 |
+
B = z_ref.shape[0]
|
| 52 |
+
x = self.input_proj(z_ref)
|
| 53 |
+
x = self.convnext(x, mask=mask)
|
| 54 |
+
kv = x.transpose(1, 2)
|
| 55 |
+
|
| 56 |
+
key_padding_mask = None
|
| 57 |
+
if mask is not None:
|
| 58 |
+
key_padding_mask = (mask.squeeze(1) == 0)
|
| 59 |
+
|
| 60 |
+
q0 = self.ref_keys.unsqueeze(0).expand(B, -1, -1)
|
| 61 |
+
q0 = self.q_proj(q0)
|
| 62 |
+
|
| 63 |
+
q1, _ = self.attn1(query=q0, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
|
| 64 |
+
q2 = q0 + q1
|
| 65 |
+
out, _ = self.attn2(query=q2, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
|
| 66 |
+
return self.out_proj(out)
|
| 67 |
+
|
| 68 |
+
@staticmethod
|
| 69 |
+
def remap_legacy_state_dict(state_dict: dict) -> dict:
|
| 70 |
+
"""Remap pre-refactor checkpoints (per-layer pre-norm + FFN) onto current layout."""
|
| 71 |
+
remapped = {}
|
| 72 |
+
legacy_prefix_map = {
|
| 73 |
+
"attn_layers.0.attn.": "attn1.",
|
| 74 |
+
"attn_layers.1.attn.": "attn2.",
|
| 75 |
+
}
|
| 76 |
+
drop_substrings = (".norm_q.", ".norm_kv.", ".ffn.", "pos_emb.")
|
| 77 |
+
for k, v in state_dict.items():
|
| 78 |
+
if any(s in k for s in drop_substrings):
|
| 79 |
+
continue
|
| 80 |
+
new_key = k
|
| 81 |
+
for old, new in legacy_prefix_map.items():
|
| 82 |
+
if new_key.startswith(old):
|
| 83 |
+
new_key = new + new_key[len(old):]
|
| 84 |
+
break
|
| 85 |
+
remapped[new_key] = v
|
| 86 |
+
return remapped
|
models/text_encoder.py
ADDED
|
@@ -0,0 +1,358 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
|
| 5 |
+
|
| 6 |
+
class LayerNorm(nn.Module):
|
| 7 |
+
def __init__(self, channels: int, eps: float = 1e-6):
|
| 8 |
+
super().__init__()
|
| 9 |
+
self.norm = nn.LayerNorm(channels, eps=eps)
|
| 10 |
+
|
| 11 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 12 |
+
x = x.transpose(1, 2)
|
| 13 |
+
x = self.norm(x)
|
| 14 |
+
x = x.transpose(1, 2)
|
| 15 |
+
return x
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class ConvNeXtBlock(nn.Module):
|
| 19 |
+
def __init__(self, dim: int, expansion_factor: int = 4, kernel_size: int = 5, dilation: int = 1, layer_scale_init_value: float = 1e-6):
|
| 20 |
+
super().__init__()
|
| 21 |
+
hidden_dim = dim * expansion_factor
|
| 22 |
+
if (kernel_size % 2) != 1:
|
| 23 |
+
raise ValueError(f"ConvNeXtBlock expects odd kernel_size, got {kernel_size}")
|
| 24 |
+
self.pad = ((kernel_size - 1) // 2) * dilation
|
| 25 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=0, groups=dim, dilation=dilation)
|
| 26 |
+
self.norm = LayerNorm(dim, eps=1e-6)
|
| 27 |
+
self.pwconv1 = nn.Conv1d(dim, hidden_dim, kernel_size=1)
|
| 28 |
+
self.act = nn.GELU()
|
| 29 |
+
self.pwconv2 = nn.Conv1d(hidden_dim, dim, kernel_size=1)
|
| 30 |
+
self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1)), requires_grad=True)
|
| 31 |
+
|
| 32 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 33 |
+
if mask is not None:
|
| 34 |
+
x = x * mask
|
| 35 |
+
residual = x
|
| 36 |
+
|
| 37 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
| 38 |
+
x = self.dwconv(x)
|
| 39 |
+
if mask is not None:
|
| 40 |
+
x = x * mask
|
| 41 |
+
|
| 42 |
+
x = self.norm(x)
|
| 43 |
+
x = self.pwconv1(x)
|
| 44 |
+
x = self.act(x)
|
| 45 |
+
x = self.pwconv2(x)
|
| 46 |
+
x = self.gamma * x
|
| 47 |
+
|
| 48 |
+
x = residual + x
|
| 49 |
+
if mask is not None:
|
| 50 |
+
x = x * mask
|
| 51 |
+
return x
|
| 52 |
+
|
| 53 |
+
|
| 54 |
+
class ConvNeXtWrapper(nn.Module):
|
| 55 |
+
def __init__(self, d_model, n_layers, expansion_factor, kernel_size=5, dilation_lst=None):
|
| 56 |
+
super().__init__()
|
| 57 |
+
if dilation_lst is None:
|
| 58 |
+
dilation_lst = [1] * n_layers
|
| 59 |
+
self.convnext = nn.ModuleList([
|
| 60 |
+
ConvNeXtBlock(d_model, expansion_factor=expansion_factor, kernel_size=kernel_size, dilation=dilation_lst[i])
|
| 61 |
+
for i in range(n_layers)
|
| 62 |
+
])
|
| 63 |
+
|
| 64 |
+
def forward(self, x, mask=None):
|
| 65 |
+
for block in self.convnext:
|
| 66 |
+
x = block(x, mask=mask)
|
| 67 |
+
return x
|
| 68 |
+
|
| 69 |
+
|
| 70 |
+
class RelativeMultiHeadAttention(nn.Module):
|
| 71 |
+
def __init__(self, channels: int, n_heads: int, window_size: int = 4, p_dropout: float = 0.0):
|
| 72 |
+
super().__init__()
|
| 73 |
+
assert channels % n_heads == 0
|
| 74 |
+
self.channels = channels
|
| 75 |
+
self.n_heads = n_heads
|
| 76 |
+
self.head_dim = channels // n_heads
|
| 77 |
+
self.scale = self.head_dim ** -0.5
|
| 78 |
+
self.window_size = window_size
|
| 79 |
+
|
| 80 |
+
self.conv_q = nn.Conv1d(channels, channels, 1)
|
| 81 |
+
self.conv_k = nn.Conv1d(channels, channels, 1)
|
| 82 |
+
self.conv_v = nn.Conv1d(channels, channels, 1)
|
| 83 |
+
self.conv_o = nn.Conv1d(channels, channels, 1)
|
| 84 |
+
|
| 85 |
+
self.emb_rel_k = nn.Parameter(torch.randn(1, 2 * window_size + 1, self.head_dim) * 0.02)
|
| 86 |
+
self.emb_rel_v = nn.Parameter(torch.randn(1, 2 * window_size + 1, self.head_dim) * 0.02)
|
| 87 |
+
|
| 88 |
+
self.drop = nn.Dropout(p_dropout)
|
| 89 |
+
|
| 90 |
+
def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 91 |
+
B, C, L = x.shape
|
| 92 |
+
|
| 93 |
+
q = self.conv_q(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
|
| 94 |
+
q = q * self.scale
|
| 95 |
+
k = self.conv_k(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
|
| 96 |
+
v = self.conv_v(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
|
| 97 |
+
|
| 98 |
+
scores = torch.matmul(q, k.transpose(-2, -1))
|
| 99 |
+
|
| 100 |
+
t = torch.arange(L, device=x.device)
|
| 101 |
+
diff = t[None, :] - t[:, None]
|
| 102 |
+
window_mask = (diff.abs() <= self.window_size)
|
| 103 |
+
diff_clamped = torch.clamp(diff, -self.window_size, self.window_size)
|
| 104 |
+
indices = diff_clamped + self.window_size
|
| 105 |
+
|
| 106 |
+
rel_k = self.emb_rel_k[0][indices]
|
| 107 |
+
rel_scores = torch.einsum("bhld,ljd->bhlj", q, rel_k)
|
| 108 |
+
rel_scores = rel_scores * window_mask[None, None, :, :]
|
| 109 |
+
|
| 110 |
+
scores = scores + rel_scores
|
| 111 |
+
|
| 112 |
+
if attn_mask is not None:
|
| 113 |
+
scores = scores.masked_fill(attn_mask == 0, -1e4)
|
| 114 |
+
|
| 115 |
+
attn = torch.softmax(scores, dim=-1)
|
| 116 |
+
attn = self.drop(attn)
|
| 117 |
+
|
| 118 |
+
out = torch.matmul(attn, v)
|
| 119 |
+
|
| 120 |
+
rel_v = self.emb_rel_v[0][indices]
|
| 121 |
+
rel_v = rel_v * window_mask[:, :, None]
|
| 122 |
+
out_rel = torch.einsum("bhlj,ljd->bhld", attn, rel_v)
|
| 123 |
+
|
| 124 |
+
out = out + out_rel
|
| 125 |
+
out = out.transpose(2, 3).contiguous().view(B, C, L)
|
| 126 |
+
out = self.conv_o(out)
|
| 127 |
+
return out
|
| 128 |
+
|
| 129 |
+
|
| 130 |
+
class FeedForward(nn.Module):
|
| 131 |
+
def __init__(self, channels: int, filter_channels: int, kernel_size: int = 1, p_dropout: float = 0.0):
|
| 132 |
+
super().__init__()
|
| 133 |
+
self.conv_1 = nn.Conv1d(channels, filter_channels, kernel_size)
|
| 134 |
+
self.relu = nn.ReLU()
|
| 135 |
+
self.drop = nn.Dropout(p_dropout)
|
| 136 |
+
self.conv_2 = nn.Conv1d(filter_channels, channels, kernel_size)
|
| 137 |
+
|
| 138 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 139 |
+
if mask is not None:
|
| 140 |
+
x = x * mask
|
| 141 |
+
x = self.conv_1(x)
|
| 142 |
+
x = self.relu(x)
|
| 143 |
+
x = self.drop(x)
|
| 144 |
+
if mask is not None:
|
| 145 |
+
x = x * mask
|
| 146 |
+
x = self.conv_2(x)
|
| 147 |
+
if mask is not None:
|
| 148 |
+
x = x * mask
|
| 149 |
+
return x
|
| 150 |
+
|
| 151 |
+
|
| 152 |
+
class AttnEncoder(nn.Module):
|
| 153 |
+
def __init__(self, channels: int, n_heads: int, filter_channels: int, n_layers: int, p_dropout: float = 0.0):
|
| 154 |
+
super().__init__()
|
| 155 |
+
self.attn_layers = nn.ModuleList(
|
| 156 |
+
[RelativeMultiHeadAttention(channels, n_heads, window_size=4, p_dropout=p_dropout) for _ in range(n_layers)]
|
| 157 |
+
)
|
| 158 |
+
self.norm_layers_1 = nn.ModuleList([LayerNorm(channels) for _ in range(n_layers)])
|
| 159 |
+
self.ffn_layers = nn.ModuleList(
|
| 160 |
+
[FeedForward(channels, filter_channels, p_dropout=p_dropout) for _ in range(n_layers)]
|
| 161 |
+
)
|
| 162 |
+
self.norm_layers_2 = nn.ModuleList([LayerNorm(channels) for _ in range(n_layers)])
|
| 163 |
+
|
| 164 |
+
def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 165 |
+
if mask is not None:
|
| 166 |
+
x = x * mask
|
| 167 |
+
|
| 168 |
+
attn_mask = None
|
| 169 |
+
if mask is not None:
|
| 170 |
+
attn_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
|
| 171 |
+
|
| 172 |
+
for i in range(len(self.attn_layers)):
|
| 173 |
+
residual = x
|
| 174 |
+
x = self.attn_layers[i](x, attn_mask=attn_mask)
|
| 175 |
+
x = residual + x
|
| 176 |
+
x = self.norm_layers_1[i](x)
|
| 177 |
+
|
| 178 |
+
residual_ffn = x
|
| 179 |
+
x_ffn = self.ffn_layers[i](x, mask=mask)
|
| 180 |
+
x = residual_ffn + x_ffn
|
| 181 |
+
x = self.norm_layers_2[i](x)
|
| 182 |
+
|
| 183 |
+
if mask is not None:
|
| 184 |
+
x = x * mask
|
| 185 |
+
return x
|
| 186 |
+
|
| 187 |
+
|
| 188 |
+
class LinearWrapped(nn.Module):
|
| 189 |
+
def __init__(self, in_dim, out_dim=None):
|
| 190 |
+
super().__init__()
|
| 191 |
+
if out_dim is None:
|
| 192 |
+
out_dim = in_dim
|
| 193 |
+
self.linear = nn.Linear(in_dim, out_dim)
|
| 194 |
+
|
| 195 |
+
def forward(self, x):
|
| 196 |
+
return self.linear(x)
|
| 197 |
+
|
| 198 |
+
|
| 199 |
+
class StyleNorm(nn.Module):
|
| 200 |
+
def __init__(self, dim, eps: float = 1e-6):
|
| 201 |
+
super().__init__()
|
| 202 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
| 203 |
+
|
| 204 |
+
def forward(self, x):
|
| 205 |
+
x = self.norm(x)
|
| 206 |
+
x = x.transpose(1, 2)
|
| 207 |
+
return x
|
| 208 |
+
|
| 209 |
+
|
| 210 |
+
class TextEmbedderWrapper(nn.Module):
|
| 211 |
+
def __init__(self, vocab_size, d_model):
|
| 212 |
+
super().__init__()
|
| 213 |
+
self.char_embedder = nn.Embedding(vocab_size, d_model)
|
| 214 |
+
|
| 215 |
+
def forward(self, x):
|
| 216 |
+
return self.char_embedder(x)
|
| 217 |
+
|
| 218 |
+
|
| 219 |
+
class StyleAttentionLayer(nn.Module):
|
| 220 |
+
def __init__(self, text_dim: int, style_dim: int, n_units: int, num_heads: int = 2, num_style_tokens: int = 50):
|
| 221 |
+
super().__init__()
|
| 222 |
+
assert n_units % num_heads == 0
|
| 223 |
+
self.num_heads = num_heads
|
| 224 |
+
self.dim = n_units
|
| 225 |
+
self.head_dim = n_units // num_heads
|
| 226 |
+
self.scale = n_units ** -0.5
|
| 227 |
+
|
| 228 |
+
self.W_query = LinearWrapped(text_dim, n_units)
|
| 229 |
+
self.W_value = LinearWrapped(style_dim, n_units)
|
| 230 |
+
self.out_fc = LinearWrapped(n_units, text_dim)
|
| 231 |
+
|
| 232 |
+
# ONNX folds `tanh(W_key(style_key))` into a baked constant; mirror with a learnable parameter.
|
| 233 |
+
self.key_const = nn.Parameter(torch.randn(num_heads, 1, self.head_dim, num_style_tokens) * 0.02)
|
| 234 |
+
|
| 235 |
+
def forward(self, x: torch.Tensor, values: torch.Tensor, mask_t: torch.Tensor | None = None) -> torch.Tensor:
|
| 236 |
+
B, T, C = x.shape
|
| 237 |
+
|
| 238 |
+
q = self.W_query(x)
|
| 239 |
+
qs = q.chunk(self.num_heads, dim=-1)
|
| 240 |
+
q = torch.stack(qs, dim=0)
|
| 241 |
+
|
| 242 |
+
k = self.key_const
|
| 243 |
+
|
| 244 |
+
if values.dim() == 2:
|
| 245 |
+
values = values.unsqueeze(0)
|
| 246 |
+
if values.shape[0] != B:
|
| 247 |
+
values = values.expand(B, -1, -1)
|
| 248 |
+
|
| 249 |
+
v = self.W_value(values)
|
| 250 |
+
vs = v.chunk(self.num_heads, dim=-1)
|
| 251 |
+
v = torch.stack(vs, dim=0)
|
| 252 |
+
|
| 253 |
+
scores = torch.matmul(q, k) * self.scale
|
| 254 |
+
attn = torch.softmax(scores, dim=-1)
|
| 255 |
+
|
| 256 |
+
if mask_t is not None:
|
| 257 |
+
attn_mask = (mask_t.unsqueeze(0) == 0)
|
| 258 |
+
attn = attn.masked_fill(attn_mask, 0.0)
|
| 259 |
+
|
| 260 |
+
out = torch.matmul(attn, v)
|
| 261 |
+
|
| 262 |
+
outs = out.chunk(self.num_heads, dim=0)
|
| 263 |
+
out = torch.cat(outs, dim=-1).squeeze(0)
|
| 264 |
+
|
| 265 |
+
out = self.out_fc(out)
|
| 266 |
+
|
| 267 |
+
if mask_t is not None:
|
| 268 |
+
out = out * mask_t
|
| 269 |
+
return out
|
| 270 |
+
|
| 271 |
+
|
| 272 |
+
class StyleAttention(nn.Module):
|
| 273 |
+
def __init__(self, text_dim: int, style_dim: int, n_units: int, num_heads: int = 2, num_style_tokens: int = 50):
|
| 274 |
+
super().__init__()
|
| 275 |
+
# attention1 / attention2 are separate: each owns its baked key constant.
|
| 276 |
+
self.attention1 = StyleAttentionLayer(text_dim, style_dim, n_units, num_heads, num_style_tokens)
|
| 277 |
+
self.attention2 = StyleAttentionLayer(text_dim, style_dim, n_units, num_heads, num_style_tokens)
|
| 278 |
+
self.norm = StyleNorm(text_dim)
|
| 279 |
+
|
| 280 |
+
def forward(self, x: torch.Tensor, style_values: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 281 |
+
x = x.transpose(1, 2)
|
| 282 |
+
|
| 283 |
+
mask_t = None
|
| 284 |
+
if mask is not None:
|
| 285 |
+
mask_t = mask.transpose(1, 2)
|
| 286 |
+
|
| 287 |
+
out1 = self.attention1(x, style_values, mask_t=mask_t)
|
| 288 |
+
x1 = x + out1
|
| 289 |
+
|
| 290 |
+
out2 = self.attention2(x1, style_values, mask_t=mask_t)
|
| 291 |
+
x2 = x + out2
|
| 292 |
+
|
| 293 |
+
x = self.norm(x2)
|
| 294 |
+
if mask is not None:
|
| 295 |
+
x = x * mask
|
| 296 |
+
return x
|
| 297 |
+
|
| 298 |
+
|
| 299 |
+
class TextEncoder(nn.Module):
|
| 300 |
+
def __init__(
|
| 301 |
+
self,
|
| 302 |
+
vocab_size: int = 256,
|
| 303 |
+
d_model: int = 256,
|
| 304 |
+
n_conv_layers: int = 6,
|
| 305 |
+
n_attn_layers: int = 4,
|
| 306 |
+
expansion_factor: int = 4,
|
| 307 |
+
p_dropout: float = 0.1,
|
| 308 |
+
kernel_size: int = 5,
|
| 309 |
+
dilation_lst: list = None,
|
| 310 |
+
attn_n_heads: int = 4,
|
| 311 |
+
attn_filter_channels: int = 1024,
|
| 312 |
+
spte_n_heads: int = 2,
|
| 313 |
+
spte_text_dim: int = 256,
|
| 314 |
+
spte_style_dim: int = 256,
|
| 315 |
+
spte_n_units: int = 256,
|
| 316 |
+
spte_n_style: int = 50,
|
| 317 |
+
):
|
| 318 |
+
super().__init__()
|
| 319 |
+
self.d_model = d_model
|
| 320 |
+
self.text_embedder = TextEmbedderWrapper(vocab_size, d_model)
|
| 321 |
+
self.convnext = ConvNeXtWrapper(
|
| 322 |
+
d_model, n_conv_layers, expansion_factor, kernel_size=kernel_size, dilation_lst=dilation_lst
|
| 323 |
+
)
|
| 324 |
+
self.attn_encoder = AttnEncoder(
|
| 325 |
+
d_model,
|
| 326 |
+
n_heads=attn_n_heads,
|
| 327 |
+
filter_channels=attn_filter_channels,
|
| 328 |
+
n_layers=n_attn_layers,
|
| 329 |
+
p_dropout=p_dropout,
|
| 330 |
+
)
|
| 331 |
+
self.speech_prompted_text_encoder = StyleAttention(
|
| 332 |
+
text_dim=spte_text_dim,
|
| 333 |
+
style_dim=spte_style_dim,
|
| 334 |
+
n_units=spte_n_units,
|
| 335 |
+
num_heads=spte_n_heads,
|
| 336 |
+
num_style_tokens=spte_n_style,
|
| 337 |
+
)
|
| 338 |
+
self.proj_out = nn.Identity()
|
| 339 |
+
|
| 340 |
+
def forward(self, text_ids: torch.Tensor, style_ttl: torch.Tensor, text_mask: torch.Tensor | None = None) -> torch.Tensor:
|
| 341 |
+
x = self.text_embedder(text_ids)
|
| 342 |
+
x = x.transpose(1, 2)
|
| 343 |
+
|
| 344 |
+
if text_mask is not None:
|
| 345 |
+
x = x * text_mask
|
| 346 |
+
|
| 347 |
+
x = self.convnext(x, mask=text_mask)
|
| 348 |
+
convnext_output = x
|
| 349 |
+
|
| 350 |
+
x = self.attn_encoder(x, mask=text_mask)
|
| 351 |
+
x = x + convnext_output
|
| 352 |
+
|
| 353 |
+
x = self.proj_out(x)
|
| 354 |
+
if text_mask is not None:
|
| 355 |
+
x = x * text_mask
|
| 356 |
+
|
| 357 |
+
x = self.speech_prompted_text_encoder(x, style_values=style_ttl, mask=text_mask)
|
| 358 |
+
return x
|
models/utils.py
ADDED
|
@@ -0,0 +1,197 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import torch
|
| 2 |
+
import torch.nn as nn
|
| 3 |
+
import torch.nn.functional as F
|
| 4 |
+
import torchaudio.transforms as T
|
| 5 |
+
|
| 6 |
+
|
| 7 |
+
def compress_latents(z: torch.Tensor, factor: int = 6) -> torch.Tensor:
|
| 8 |
+
B, C, T = z.shape
|
| 9 |
+
if T % factor != 0:
|
| 10 |
+
pad = factor - (T % factor)
|
| 11 |
+
z = torch.nn.functional.pad(z, (0, pad))
|
| 12 |
+
T = T + pad
|
| 13 |
+
return z.view(B, C, T // factor, factor).permute(0, 1, 3, 2).flatten(1, 2)
|
| 14 |
+
|
| 15 |
+
|
| 16 |
+
def decompress_latents(z: torch.Tensor, factor: int = 6, target_channels: int = 24) -> torch.Tensor:
|
| 17 |
+
B, _, T_low = z.shape
|
| 18 |
+
return z.view(B, target_channels, factor, T_low).permute(0, 1, 3, 2).flatten(2, 3)
|
| 19 |
+
|
| 20 |
+
|
| 21 |
+
def _resolve_vocab_size(char_dict_path, default=256):
|
| 22 |
+
import json as _json
|
| 23 |
+
import os as _os
|
| 24 |
+
if char_dict_path and _os.path.exists(char_dict_path):
|
| 25 |
+
try:
|
| 26 |
+
with open(char_dict_path, "r") as f:
|
| 27 |
+
cd = _json.load(f)
|
| 28 |
+
if isinstance(cd, dict) and "vocab_size" in cd:
|
| 29 |
+
return int(cd["vocab_size"])
|
| 30 |
+
if isinstance(cd, dict) and "char_to_id" in cd and isinstance(cd["char_to_id"], dict):
|
| 31 |
+
return max(cd["char_to_id"].values()) + 1
|
| 32 |
+
if isinstance(cd, dict):
|
| 33 |
+
return max(cd.values()) + 1 if cd else default
|
| 34 |
+
return len(cd)
|
| 35 |
+
except Exception:
|
| 36 |
+
pass
|
| 37 |
+
return default
|
| 38 |
+
|
| 39 |
+
|
| 40 |
+
def load_ttl_config(config_path="configs/tts.json"):
|
| 41 |
+
import json
|
| 42 |
+
with open(config_path, "r") as f:
|
| 43 |
+
full_config = json.load(f)
|
| 44 |
+
|
| 45 |
+
ttl = full_config["ttl"]
|
| 46 |
+
ae = full_config.get("ae", {})
|
| 47 |
+
dp = full_config.get("dp", {})
|
| 48 |
+
|
| 49 |
+
te = ttl["text_encoder"]
|
| 50 |
+
se = ttl["style_encoder"]
|
| 51 |
+
vf = ttl["vector_field"]
|
| 52 |
+
um = ttl["uncond_masker"]
|
| 53 |
+
|
| 54 |
+
char_dict_path = te.get("char_dict_path", te.get("text_embedder", {}).get("char_dict_path"))
|
| 55 |
+
vocab_size = _resolve_vocab_size(char_dict_path, default=256)
|
| 56 |
+
|
| 57 |
+
dp_char_dict_path = (
|
| 58 |
+
dp.get("sentence_encoder", {}).get("char_dict_path")
|
| 59 |
+
or dp.get("sentence_encoder", {}).get("text_embedder", {}).get("char_dict_path")
|
| 60 |
+
)
|
| 61 |
+
dp_vocab_size = _resolve_vocab_size(dp_char_dict_path, default=vocab_size)
|
| 62 |
+
|
| 63 |
+
ae_dec = ae.get("decoder", {})
|
| 64 |
+
ae_dec_cfg = {
|
| 65 |
+
"idim": ae_dec.get("idim", 24),
|
| 66 |
+
"hdim": ae_dec.get("hdim", 512),
|
| 67 |
+
"intermediate_dim": ae_dec.get("intermediate_dim", 2048),
|
| 68 |
+
"ksz": ae_dec.get("ksz", 7),
|
| 69 |
+
"dilation_lst": ae_dec.get("dilation_lst", [1, 2, 4, 1, 2, 4, 1, 1, 1, 1]),
|
| 70 |
+
"chunk_compress_factor": ae.get("chunk_compress_factor", 1),
|
| 71 |
+
"head": {
|
| 72 |
+
"idim": ae_dec.get("head", {}).get("idim", ae_dec.get("hdim", 512)),
|
| 73 |
+
"hdim": ae_dec.get("head", {}).get("hdim", 2048),
|
| 74 |
+
"odim": ae_dec.get("head", {}).get("odim", 512),
|
| 75 |
+
"ksz": ae_dec.get("head", {}).get("ksz", 3),
|
| 76 |
+
},
|
| 77 |
+
}
|
| 78 |
+
|
| 79 |
+
ae_enc = ae.get("encoder", {})
|
| 80 |
+
ae_enc_spec = ae_enc.get("spec_processor", {})
|
| 81 |
+
ae_enc_cfg = {
|
| 82 |
+
"ksz": ae_enc.get("ksz", 7),
|
| 83 |
+
"hdim": ae_enc.get("hdim", 512),
|
| 84 |
+
"intermediate_dim": ae_enc.get("intermediate_dim", 2048),
|
| 85 |
+
"dilation_lst": ae_enc.get("dilation_lst", [1] * 10),
|
| 86 |
+
"odim": ae_enc.get("odim", 24),
|
| 87 |
+
"idim": ae_enc.get("idim", 1253),
|
| 88 |
+
}
|
| 89 |
+
|
| 90 |
+
dp_se = dp.get("style_encoder", {}).get("style_token_layer", {})
|
| 91 |
+
|
| 92 |
+
return {
|
| 93 |
+
"full_config": full_config,
|
| 94 |
+
"ttl": ttl,
|
| 95 |
+
"ae": ae,
|
| 96 |
+
"dp": dp,
|
| 97 |
+
|
| 98 |
+
"vocab_size": vocab_size,
|
| 99 |
+
"char_dict_path": char_dict_path,
|
| 100 |
+
"dp_vocab_size": dp_vocab_size,
|
| 101 |
+
|
| 102 |
+
"latent_dim": ttl["latent_dim"],
|
| 103 |
+
"chunk_compress_factor": ttl["chunk_compress_factor"],
|
| 104 |
+
"compressed_channels": ttl["latent_dim"] * ttl["chunk_compress_factor"],
|
| 105 |
+
"normalizer_scale": ttl["normalizer"]["scale"],
|
| 106 |
+
"sigma_min": ttl["flow_matching"]["sig_min"],
|
| 107 |
+
"Ke": ttl["batch_expander"]["n_batch_expand"],
|
| 108 |
+
|
| 109 |
+
"te_d_model": te["text_embedder"]["char_emb_dim"],
|
| 110 |
+
"te_convnext_layers": te["convnext"]["num_layers"],
|
| 111 |
+
"te_expansion_factor": te["convnext"]["intermediate_dim"] // te["text_embedder"]["char_emb_dim"],
|
| 112 |
+
"te_attn_n_layers": te["attn_encoder"]["n_layers"],
|
| 113 |
+
"te_attn_p_dropout": te["attn_encoder"]["p_dropout"],
|
| 114 |
+
|
| 115 |
+
"se_d_model": se["proj_in"]["odim"],
|
| 116 |
+
"se_hidden_dim": se["convnext"]["intermediate_dim"],
|
| 117 |
+
"se_num_blocks": se["convnext"]["num_layers"],
|
| 118 |
+
"se_n_style": se["style_token_layer"]["n_style"],
|
| 119 |
+
"se_n_heads": se["style_token_layer"]["n_heads"],
|
| 120 |
+
|
| 121 |
+
"prob_both_uncond": um["prob_both_uncond"],
|
| 122 |
+
"prob_text_uncond": um["prob_text_uncond"],
|
| 123 |
+
"uncond_init_std": um["std"],
|
| 124 |
+
"um_text_dim": um["text_dim"],
|
| 125 |
+
"um_n_style": um["n_style"],
|
| 126 |
+
"um_style_key_dim": um["style_key_dim"],
|
| 127 |
+
"um_style_value_dim": um["style_value_dim"],
|
| 128 |
+
|
| 129 |
+
"vf_hidden": vf["proj_in"]["odim"],
|
| 130 |
+
"vf_time_dim": vf["time_encoder"]["time_dim"],
|
| 131 |
+
"vf_n_blocks": vf["main_blocks"]["n_blocks"],
|
| 132 |
+
"vf_text_dim": vf["main_blocks"]["text_cond_layer"]["text_dim"],
|
| 133 |
+
"vf_text_n_heads": vf["main_blocks"]["text_cond_layer"]["n_heads"],
|
| 134 |
+
"vf_style_dim": vf["main_blocks"]["style_cond_layer"]["style_dim"],
|
| 135 |
+
"vf_rotary_scale": vf["main_blocks"]["text_cond_layer"]["rotary_scale"],
|
| 136 |
+
|
| 137 |
+
"ae_dec_cfg": ae_dec_cfg,
|
| 138 |
+
"ae_enc_cfg": ae_enc_cfg,
|
| 139 |
+
"ae_sample_rate": ae.get("sample_rate", 44100),
|
| 140 |
+
"ae_n_fft": ae_enc_spec.get("n_fft", 2048),
|
| 141 |
+
"ae_hop_length": ae_enc_spec.get("hop_length", 512),
|
| 142 |
+
"ae_n_mels": ae_enc_spec.get("n_mels", 1253),
|
| 143 |
+
|
| 144 |
+
"dp_style_tokens": dp_se.get("n_style", 8),
|
| 145 |
+
"dp_style_dim": dp_se.get("style_value_dim", 16),
|
| 146 |
+
}
|
| 147 |
+
|
| 148 |
+
|
| 149 |
+
class MelSpectrogram(nn.Module):
|
| 150 |
+
def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
|
| 151 |
+
hop_length=512, n_mels=1253, f_min=0, f_max=None):
|
| 152 |
+
super().__init__()
|
| 153 |
+
self.mel = T.MelSpectrogram(
|
| 154 |
+
sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
|
| 155 |
+
hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
|
| 156 |
+
center=True, power=1.0,
|
| 157 |
+
)
|
| 158 |
+
|
| 159 |
+
def forward(self, audio):
|
| 160 |
+
mel = torch.log(torch.clamp(self.mel(audio), min=1e-5))
|
| 161 |
+
return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel
|
| 162 |
+
|
| 163 |
+
|
| 164 |
+
class MelSpectrogramNoLog(nn.Module):
|
| 165 |
+
def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
|
| 166 |
+
hop_length=512, n_mels=1253, f_min=0, f_max=12000, power=1.0):
|
| 167 |
+
super().__init__()
|
| 168 |
+
self.mel = T.MelSpectrogram(
|
| 169 |
+
sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
|
| 170 |
+
hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
|
| 171 |
+
center=True, power=power,
|
| 172 |
+
)
|
| 173 |
+
|
| 174 |
+
def forward(self, audio):
|
| 175 |
+
mel = self.mel(audio)
|
| 176 |
+
return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel
|
| 177 |
+
|
| 178 |
+
|
| 179 |
+
class LinearMelSpectrogram(nn.Module):
|
| 180 |
+
def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
|
| 181 |
+
hop_length=512, n_mels=1253, f_min=0, f_max=None):
|
| 182 |
+
super().__init__()
|
| 183 |
+
self.spectrogram = T.Spectrogram(
|
| 184 |
+
n_fft=n_fft, win_length=win_length, hop_length=hop_length,
|
| 185 |
+
center=True, power=1.0,
|
| 186 |
+
)
|
| 187 |
+
self.mel_scale = T.MelScale(
|
| 188 |
+
n_mels=n_mels, sample_rate=sample_rate,
|
| 189 |
+
n_stft=n_fft // 2 + 1, f_min=f_min, f_max=f_max,
|
| 190 |
+
)
|
| 191 |
+
|
| 192 |
+
def forward(self, audio):
|
| 193 |
+
spec = self.spectrogram(audio)
|
| 194 |
+
mel = self.mel_scale(spec)
|
| 195 |
+
log_spec = torch.log(torch.clamp(spec, min=1e-5))
|
| 196 |
+
log_mel = torch.log(torch.clamp(mel, min=1e-5))
|
| 197 |
+
return torch.cat([log_spec, log_mel], dim=1)
|
models/vf_estimator.py
ADDED
|
@@ -0,0 +1,507 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import math
|
| 2 |
+
from typing import Optional
|
| 3 |
+
|
| 4 |
+
import torch
|
| 5 |
+
import torch.nn as nn
|
| 6 |
+
import torch.nn.functional as F
|
| 7 |
+
|
| 8 |
+
|
| 9 |
+
class LinearWrapper(nn.Module):
|
| 10 |
+
def __init__(self, in_features: int, out_features: int):
|
| 11 |
+
super().__init__()
|
| 12 |
+
self.linear = nn.Linear(in_features, out_features)
|
| 13 |
+
|
| 14 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 15 |
+
return self.linear(x)
|
| 16 |
+
|
| 17 |
+
|
| 18 |
+
class LayerNormWrapper(nn.Module):
|
| 19 |
+
def __init__(self, dim: int, eps: float = 1e-6):
|
| 20 |
+
super().__init__()
|
| 21 |
+
self.norm = nn.LayerNorm(dim, eps=eps)
|
| 22 |
+
|
| 23 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 24 |
+
x = x.transpose(1, 2)
|
| 25 |
+
x = self.norm(x)
|
| 26 |
+
x = x.transpose(1, 2)
|
| 27 |
+
return x
|
| 28 |
+
|
| 29 |
+
|
| 30 |
+
class ProjectionWrapper(nn.Module):
|
| 31 |
+
def __init__(self, in_channels: int, out_channels: int):
|
| 32 |
+
super().__init__()
|
| 33 |
+
self.net = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
|
| 34 |
+
|
| 35 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 36 |
+
return self.net(x)
|
| 37 |
+
|
| 38 |
+
|
| 39 |
+
class Mish(nn.Module):
|
| 40 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 41 |
+
return x * torch.tanh(F.softplus(x))
|
| 42 |
+
|
| 43 |
+
|
| 44 |
+
class SinusoidalPosEmb(nn.Module):
|
| 45 |
+
def __init__(self, dim: int, scale: float = 1000.0):
|
| 46 |
+
super().__init__()
|
| 47 |
+
self.dim = dim
|
| 48 |
+
self.scale = scale
|
| 49 |
+
|
| 50 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 51 |
+
x = x * self.scale
|
| 52 |
+
device = x.device
|
| 53 |
+
half_dim = self.dim // 2
|
| 54 |
+
emb = math.log(10000) / (half_dim - 1)
|
| 55 |
+
emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
|
| 56 |
+
emb = x[:, None] * emb[None, :]
|
| 57 |
+
emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
|
| 58 |
+
return emb
|
| 59 |
+
|
| 60 |
+
|
| 61 |
+
class TimeEncoder(nn.Module):
|
| 62 |
+
def __init__(self, embed_dim: int, hdim: int = 256):
|
| 63 |
+
super().__init__()
|
| 64 |
+
self.sinusoidal = SinusoidalPosEmb(embed_dim, scale=1000.0)
|
| 65 |
+
self.mlp = nn.Sequential(
|
| 66 |
+
LinearWrapper(embed_dim, hdim),
|
| 67 |
+
Mish(),
|
| 68 |
+
LinearWrapper(hdim, embed_dim),
|
| 69 |
+
)
|
| 70 |
+
|
| 71 |
+
def forward(self, x: torch.Tensor) -> torch.Tensor:
|
| 72 |
+
x = self.sinusoidal(x)
|
| 73 |
+
x = self.mlp(x)
|
| 74 |
+
return x
|
| 75 |
+
|
| 76 |
+
|
| 77 |
+
class TimeCondBlock(nn.Module):
|
| 78 |
+
def __init__(self, time_dim: int, channels: int):
|
| 79 |
+
super().__init__()
|
| 80 |
+
self.linear = LinearWrapper(time_dim, channels)
|
| 81 |
+
# Zero-init so the block starts as identity.
|
| 82 |
+
nn.init.zeros_(self.linear.linear.weight)
|
| 83 |
+
nn.init.zeros_(self.linear.linear.bias)
|
| 84 |
+
|
| 85 |
+
def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
|
| 86 |
+
cond = self.linear(time_emb)
|
| 87 |
+
cond = cond.unsqueeze(-1)
|
| 88 |
+
return x + cond
|
| 89 |
+
|
| 90 |
+
|
| 91 |
+
class ConvNeXtBlock1D(nn.Module):
|
| 92 |
+
def __init__(self, dim: int, kernel_size: int = 5, expansion: int = 2, dropout: float = 0.0, dilation: int = 1):
|
| 93 |
+
super().__init__()
|
| 94 |
+
self.pad = ((kernel_size - 1) // 2) * dilation
|
| 95 |
+
self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=0, groups=dim, dilation=dilation)
|
| 96 |
+
self.norm = LayerNormWrapper(dim)
|
| 97 |
+
self.pwconv1 = nn.Conv1d(dim, dim * expansion, kernel_size=1)
|
| 98 |
+
self.act = nn.GELU()
|
| 99 |
+
self.pwconv2 = nn.Conv1d(dim * expansion, dim, kernel_size=1)
|
| 100 |
+
self.gamma = nn.Parameter(torch.ones(1, dim, 1) * 1e-6)
|
| 101 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
|
| 102 |
+
|
| 103 |
+
def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
|
| 104 |
+
if mask is not None:
|
| 105 |
+
x = x * mask
|
| 106 |
+
residual = x
|
| 107 |
+
|
| 108 |
+
x = F.pad(x, (self.pad, self.pad), mode="replicate")
|
| 109 |
+
x = self.dwconv(x)
|
| 110 |
+
if mask is not None:
|
| 111 |
+
x = x * mask
|
| 112 |
+
|
| 113 |
+
x = self.norm(x)
|
| 114 |
+
x = self.pwconv1(x)
|
| 115 |
+
x = self.act(x)
|
| 116 |
+
x = self.pwconv2(x)
|
| 117 |
+
x = self.gamma * x
|
| 118 |
+
x = self.dropout(x)
|
| 119 |
+
|
| 120 |
+
x = x + residual
|
| 121 |
+
if mask is not None:
|
| 122 |
+
x = x * mask
|
| 123 |
+
return x
|
| 124 |
+
|
| 125 |
+
|
| 126 |
+
class ConvNeXtStack(nn.Module):
|
| 127 |
+
def __init__(self, channels, kernel_size, dilations):
|
| 128 |
+
super().__init__()
|
| 129 |
+
self.convnext = nn.ModuleList([
|
| 130 |
+
ConvNeXtBlock1D(channels, kernel_size=kernel_size, dilation=d, expansion=2)
|
| 131 |
+
for d in dilations
|
| 132 |
+
])
|
| 133 |
+
|
| 134 |
+
def forward(self, x, mask=None):
|
| 135 |
+
for blk in self.convnext:
|
| 136 |
+
x = blk(x, mask)
|
| 137 |
+
return x
|
| 138 |
+
|
| 139 |
+
|
| 140 |
+
def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
|
| 141 |
+
B, H, T, D = x.shape
|
| 142 |
+
assert D % 2 == 0, "head_dim must be even for RoPE"
|
| 143 |
+
|
| 144 |
+
x1 = x[..., : D // 2]
|
| 145 |
+
x2 = x[..., D // 2 :]
|
| 146 |
+
|
| 147 |
+
if cos.dim() == 2:
|
| 148 |
+
cos = cos[None, None, :, :]
|
| 149 |
+
sin = sin[None, None, :, :]
|
| 150 |
+
elif cos.dim() == 3:
|
| 151 |
+
cos = cos.unsqueeze(1)
|
| 152 |
+
sin = sin.unsqueeze(1)
|
| 153 |
+
|
| 154 |
+
x1_rot = x1 * cos - x2 * sin
|
| 155 |
+
x2_rot = x1 * sin + x2 * cos
|
| 156 |
+
|
| 157 |
+
return torch.cat([x1_rot, x2_rot], dim=-1)
|
| 158 |
+
|
| 159 |
+
|
| 160 |
+
class AttentionModule(nn.Module):
|
| 161 |
+
"""Text path uses LARoPE; style path uses tanh on keys (no RoPE)."""
|
| 162 |
+
|
| 163 |
+
def __init__(
|
| 164 |
+
self,
|
| 165 |
+
d_model: int,
|
| 166 |
+
d_context: int,
|
| 167 |
+
num_heads: int,
|
| 168 |
+
attn_dim: int,
|
| 169 |
+
use_rope: bool,
|
| 170 |
+
dropout: float = 0.0,
|
| 171 |
+
rope_gamma: float = 10.0,
|
| 172 |
+
attn_scale: Optional[float] = None,
|
| 173 |
+
rotary_base: float = 10000.0,
|
| 174 |
+
use_residual: bool = True,
|
| 175 |
+
):
|
| 176 |
+
super().__init__()
|
| 177 |
+
assert attn_dim % num_heads == 0
|
| 178 |
+
|
| 179 |
+
self.d_model = d_model
|
| 180 |
+
self.num_heads = num_heads
|
| 181 |
+
self.head_dim = attn_dim // num_heads
|
| 182 |
+
self.attn_dim = attn_dim
|
| 183 |
+
self.use_rope = use_rope
|
| 184 |
+
self.use_residual = use_residual
|
| 185 |
+
self.rope_gamma = rope_gamma
|
| 186 |
+
self.attn_scale = attn_scale if attn_scale is not None else math.sqrt(self.attn_dim)
|
| 187 |
+
|
| 188 |
+
self.W_query = LinearWrapper(d_model, attn_dim)
|
| 189 |
+
self.W_key = LinearWrapper(d_context, attn_dim)
|
| 190 |
+
self.W_value = LinearWrapper(d_context, attn_dim)
|
| 191 |
+
self.out_fc = LinearWrapper(attn_dim, d_model)
|
| 192 |
+
|
| 193 |
+
self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
|
| 194 |
+
|
| 195 |
+
if use_rope:
|
| 196 |
+
inv_freq = 1.0 / (rotary_base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
|
| 197 |
+
theta = (inv_freq * rope_gamma).view(1, 1, -1)
|
| 198 |
+
self.register_buffer("theta", theta, persistent=True)
|
| 199 |
+
self.register_buffer("increments", torch.arange(1000).view(1, 1000, 1), persistent=True)
|
| 200 |
+
self.tanh = None
|
| 201 |
+
else:
|
| 202 |
+
self.theta = None
|
| 203 |
+
self.increments = None
|
| 204 |
+
self.tanh = nn.Tanh()
|
| 205 |
+
|
| 206 |
+
def forward(
|
| 207 |
+
self,
|
| 208 |
+
x: torch.Tensor,
|
| 209 |
+
context: torch.Tensor,
|
| 210 |
+
context_keys: Optional[torch.Tensor] = None,
|
| 211 |
+
x_mask: Optional[torch.Tensor] = None,
|
| 212 |
+
context_mask: Optional[torch.Tensor] = None,
|
| 213 |
+
) -> torch.Tensor:
|
| 214 |
+
B, d_model, T = x.shape
|
| 215 |
+
L = context.shape[1]
|
| 216 |
+
|
| 217 |
+
x_t = x.transpose(1, 2)
|
| 218 |
+
q = self.W_query(x_t)
|
| 219 |
+
|
| 220 |
+
k_src = context_keys if context_keys is not None else context
|
| 221 |
+
k = self.W_key(k_src)
|
| 222 |
+
v = self.W_value(context)
|
| 223 |
+
|
| 224 |
+
if not self.use_rope and self.tanh is not None:
|
| 225 |
+
k = self.tanh(k)
|
| 226 |
+
|
| 227 |
+
H = self.num_heads
|
| 228 |
+
D = self.head_dim
|
| 229 |
+
|
| 230 |
+
q = q.view(B, T, H, D).permute(2, 0, 1, 3)
|
| 231 |
+
k = k.view(B, L, H, D).permute(2, 0, 1, 3)
|
| 232 |
+
v = v.view(B, L, H, D).permute(2, 0, 1, 3)
|
| 233 |
+
|
| 234 |
+
if self.use_rope:
|
| 235 |
+
device = x.device
|
| 236 |
+
|
| 237 |
+
if x_mask is not None:
|
| 238 |
+
len_q = x_mask.sum(dim=(-2, -1)).reshape(-1, 1, 1)
|
| 239 |
+
else:
|
| 240 |
+
len_q = torch.tensor([T], device=device, dtype=torch.float32).reshape(1, 1, 1)
|
| 241 |
+
|
| 242 |
+
if context_mask is not None:
|
| 243 |
+
len_k = context_mask.sum(dim=(-2, -1)).reshape(-1, 1, 1)
|
| 244 |
+
else:
|
| 245 |
+
len_k = torch.tensor([L], device=device, dtype=torch.float32).reshape(1, 1, 1)
|
| 246 |
+
|
| 247 |
+
if self.increments is not None and self.increments.shape[1] >= max(T, L):
|
| 248 |
+
pos_q = self.increments[:, :T, :].to(device).float()
|
| 249 |
+
pos_k = self.increments[:, :L, :].to(device).float()
|
| 250 |
+
else:
|
| 251 |
+
pos_q = torch.arange(T, device=device, dtype=torch.float32).reshape(1, -1, 1)
|
| 252 |
+
pos_k = torch.arange(L, device=device, dtype=torch.float32).reshape(1, -1, 1)
|
| 253 |
+
|
| 254 |
+
norm_pos_q = pos_q / len_q
|
| 255 |
+
norm_pos_k = pos_k / len_k
|
| 256 |
+
|
| 257 |
+
theta = self.theta if self.theta is not None else (
|
| 258 |
+
(1.0 / (10000 ** (torch.arange(0, D, 2, device=device).float() / D))) * self.rope_gamma
|
| 259 |
+
).view(1, 1, -1)
|
| 260 |
+
|
| 261 |
+
freqs_q = norm_pos_q * theta
|
| 262 |
+
freqs_k = norm_pos_k * theta
|
| 263 |
+
|
| 264 |
+
cos_q, sin_q = freqs_q.cos(), freqs_q.sin()
|
| 265 |
+
cos_k, sin_k = freqs_k.cos(), freqs_k.sin()
|
| 266 |
+
|
| 267 |
+
cos_q, sin_q = cos_q.unsqueeze(0), sin_q.unsqueeze(0)
|
| 268 |
+
cos_k, sin_k = cos_k.unsqueeze(0), sin_k.unsqueeze(0)
|
| 269 |
+
|
| 270 |
+
q = apply_rotary_pos_emb(q, cos_q, sin_q)
|
| 271 |
+
k = apply_rotary_pos_emb(k, cos_k, sin_k)
|
| 272 |
+
|
| 273 |
+
attn_logits = torch.matmul(q, k.transpose(-1, -2)) / self.attn_scale
|
| 274 |
+
|
| 275 |
+
if context_mask is not None:
|
| 276 |
+
if context_mask.dim() == 2:
|
| 277 |
+
context_mask = context_mask.unsqueeze(1)
|
| 278 |
+
cm = (context_mask == 0)
|
| 279 |
+
attn_logits = attn_logits.masked_fill(cm.unsqueeze(0), float("-inf"))
|
| 280 |
+
|
| 281 |
+
attn = torch.softmax(attn_logits, dim=-1)
|
| 282 |
+
|
| 283 |
+
if x_mask is not None:
|
| 284 |
+
if x_mask.dim() == 2:
|
| 285 |
+
x_mask = x_mask.unsqueeze(1)
|
| 286 |
+
qm = (x_mask == 0).permute(1, 0, 2).unsqueeze(-1)
|
| 287 |
+
attn = attn.masked_fill(qm, 0.0)
|
| 288 |
+
|
| 289 |
+
out = torch.matmul(attn, v)
|
| 290 |
+
out = out.permute(1, 2, 0, 3).contiguous().view(B, T, self.attn_dim)
|
| 291 |
+
out = self.out_fc(out)
|
| 292 |
+
out = self.dropout(out)
|
| 293 |
+
|
| 294 |
+
if x_mask is not None:
|
| 295 |
+
out = out * x_mask.transpose(1, 2)
|
| 296 |
+
|
| 297 |
+
out = out.transpose(1, 2)
|
| 298 |
+
return out
|
| 299 |
+
|
| 300 |
+
|
| 301 |
+
class CrossAttentionBlock(nn.Module):
|
| 302 |
+
def __init__(
|
| 303 |
+
self,
|
| 304 |
+
d_model: int,
|
| 305 |
+
d_context: int,
|
| 306 |
+
num_heads: int = 8,
|
| 307 |
+
attn_dim: int = 256,
|
| 308 |
+
use_rope: bool = True,
|
| 309 |
+
rope_gamma: float = 10.0,
|
| 310 |
+
attn_scale: Optional[float] = None,
|
| 311 |
+
use_residual: bool = True,
|
| 312 |
+
rotary_base: float = 10000.0,
|
| 313 |
+
):
|
| 314 |
+
super().__init__()
|
| 315 |
+
self.use_rope = use_rope
|
| 316 |
+
self.use_residual = use_residual
|
| 317 |
+
attn_module = AttentionModule(
|
| 318 |
+
d_model, d_context, num_heads, attn_dim, use_rope,
|
| 319 |
+
rope_gamma=rope_gamma, attn_scale=attn_scale, rotary_base=rotary_base, use_residual=use_residual,
|
| 320 |
+
)
|
| 321 |
+
# Checkpoint naming: text (RoPE) -> 'attn'; style (no RoPE) -> 'attention'.
|
| 322 |
+
if use_rope:
|
| 323 |
+
self.attn = attn_module
|
| 324 |
+
else:
|
| 325 |
+
self.attention = attn_module
|
| 326 |
+
self.norm = LayerNormWrapper(d_model)
|
| 327 |
+
|
| 328 |
+
def forward(
|
| 329 |
+
self,
|
| 330 |
+
x: torch.Tensor,
|
| 331 |
+
context: torch.Tensor,
|
| 332 |
+
context_keys: Optional[torch.Tensor],
|
| 333 |
+
x_mask: Optional[torch.Tensor],
|
| 334 |
+
context_mask: Optional[torch.Tensor],
|
| 335 |
+
) -> torch.Tensor:
|
| 336 |
+
if x_mask is not None:
|
| 337 |
+
x = x * x_mask
|
| 338 |
+
|
| 339 |
+
residual = x
|
| 340 |
+
|
| 341 |
+
if self.use_rope:
|
| 342 |
+
attn_out = self.attn(x, context, context_keys, x_mask, context_mask)
|
| 343 |
+
else:
|
| 344 |
+
attn_out = self.attention(x, context, context_keys, x_mask, context_mask)
|
| 345 |
+
|
| 346 |
+
if self.use_residual:
|
| 347 |
+
x = residual + attn_out
|
| 348 |
+
else:
|
| 349 |
+
x = attn_out
|
| 350 |
+
|
| 351 |
+
x = self.norm(x)
|
| 352 |
+
if x_mask is not None:
|
| 353 |
+
x = x * x_mask
|
| 354 |
+
return x
|
| 355 |
+
|
| 356 |
+
|
| 357 |
+
class VectorFieldEstimator(nn.Module):
|
| 358 |
+
def __init__(
|
| 359 |
+
self,
|
| 360 |
+
in_channels: int = 144,
|
| 361 |
+
hidden_channels: int = 512,
|
| 362 |
+
out_channels: int = 144,
|
| 363 |
+
text_dim: int = 256,
|
| 364 |
+
style_dim: int = 256,
|
| 365 |
+
num_style_tokens: int = 50,
|
| 366 |
+
num_superblocks: int = 4,
|
| 367 |
+
time_embed_dim: int = 64,
|
| 368 |
+
rope_gamma: float = 10.0,
|
| 369 |
+
main_blocks_cfg: dict = None,
|
| 370 |
+
last_convnext_cfg: dict = None,
|
| 371 |
+
text_n_heads: int = 4,
|
| 372 |
+
time_hdim: int = 256,
|
| 373 |
+
use_residual: bool = True,
|
| 374 |
+
rotary_base: float = 10000.0,
|
| 375 |
+
):
|
| 376 |
+
super().__init__()
|
| 377 |
+
self.in_channels = in_channels
|
| 378 |
+
self.hidden_channels = hidden_channels
|
| 379 |
+
self.out_channels = out_channels
|
| 380 |
+
self.text_dim = text_dim
|
| 381 |
+
self.style_dim = style_dim
|
| 382 |
+
self.rope_gamma = rope_gamma
|
| 383 |
+
|
| 384 |
+
# Shared tiled constant ([1, 50, 256]) consumed by every style-attn W_key.
|
| 385 |
+
self.tile = nn.Parameter(torch.randn(1, num_style_tokens, style_dim) * 0.02)
|
| 386 |
+
|
| 387 |
+
self.proj_in = ProjectionWrapper(in_channels, hidden_channels)
|
| 388 |
+
self.time_encoder = TimeEncoder(time_embed_dim, hdim=time_hdim)
|
| 389 |
+
|
| 390 |
+
self.main_blocks = nn.ModuleList()
|
| 391 |
+
|
| 392 |
+
shared_attn_scale = math.sqrt(256)
|
| 393 |
+
|
| 394 |
+
mb_cfg = main_blocks_cfg or {}
|
| 395 |
+
lc_cfg = last_convnext_cfg or {}
|
| 396 |
+
|
| 397 |
+
c0_cfg = mb_cfg.get("convnext_0", {})
|
| 398 |
+
c1_cfg = mb_cfg.get("convnext_1", {})
|
| 399 |
+
c2_cfg = mb_cfg.get("convnext_2", {})
|
| 400 |
+
|
| 401 |
+
for _ in range(num_superblocks):
|
| 402 |
+
self.main_blocks.append(
|
| 403 |
+
ConvNeXtStack(hidden_channels, kernel_size=c0_cfg.get("ksz", 5), dilations=c0_cfg.get("dilation_lst", [1, 2, 4, 8]))
|
| 404 |
+
)
|
| 405 |
+
self.main_blocks.append(
|
| 406 |
+
TimeCondBlock(time_dim=time_embed_dim, channels=hidden_channels)
|
| 407 |
+
)
|
| 408 |
+
self.main_blocks.append(
|
| 409 |
+
ConvNeXtStack(hidden_channels, kernel_size=c1_cfg.get("ksz", 5), dilations=c1_cfg.get("dilation_lst", [1]))
|
| 410 |
+
)
|
| 411 |
+
self.main_blocks.append(
|
| 412 |
+
CrossAttentionBlock(
|
| 413 |
+
d_model=hidden_channels,
|
| 414 |
+
d_context=text_dim,
|
| 415 |
+
num_heads=text_n_heads,
|
| 416 |
+
attn_dim=256,
|
| 417 |
+
use_rope=True,
|
| 418 |
+
rope_gamma=self.rope_gamma,
|
| 419 |
+
attn_scale=shared_attn_scale,
|
| 420 |
+
use_residual=use_residual,
|
| 421 |
+
rotary_base=rotary_base,
|
| 422 |
+
)
|
| 423 |
+
)
|
| 424 |
+
self.main_blocks.append(
|
| 425 |
+
ConvNeXtStack(hidden_channels, kernel_size=c2_cfg.get("ksz", 5), dilations=c2_cfg.get("dilation_lst", [1]))
|
| 426 |
+
)
|
| 427 |
+
self.main_blocks.append(
|
| 428 |
+
CrossAttentionBlock(
|
| 429 |
+
d_model=hidden_channels,
|
| 430 |
+
d_context=style_dim,
|
| 431 |
+
num_heads=2,
|
| 432 |
+
attn_dim=256,
|
| 433 |
+
use_rope=False,
|
| 434 |
+
attn_scale=shared_attn_scale,
|
| 435 |
+
use_residual=use_residual,
|
| 436 |
+
rotary_base=rotary_base,
|
| 437 |
+
)
|
| 438 |
+
)
|
| 439 |
+
|
| 440 |
+
self.last_convnext = ConvNeXtStack(
|
| 441 |
+
hidden_channels, kernel_size=lc_cfg.get("ksz", 5), dilations=lc_cfg.get("dilation_lst", [1, 1, 1, 1])
|
| 442 |
+
)
|
| 443 |
+
self.proj_out = ProjectionWrapper(hidden_channels, out_channels)
|
| 444 |
+
|
| 445 |
+
def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
|
| 446 |
+
missing_keys, unexpected_keys, error_msgs):
|
| 447 |
+
# Back-compat: older checkpoints stored the tiled style-key under `style_key`.
|
| 448 |
+
legacy_key = prefix + "style_key"
|
| 449 |
+
new_key = prefix + "tile"
|
| 450 |
+
if legacy_key in state_dict and new_key not in state_dict:
|
| 451 |
+
state_dict[new_key] = state_dict.pop(legacy_key)
|
| 452 |
+
return super()._load_from_state_dict(
|
| 453 |
+
state_dict, prefix, local_metadata, strict,
|
| 454 |
+
missing_keys, unexpected_keys, error_msgs,
|
| 455 |
+
)
|
| 456 |
+
|
| 457 |
+
def forward(
|
| 458 |
+
self,
|
| 459 |
+
noisy_latent: torch.Tensor,
|
| 460 |
+
text_emb: torch.Tensor,
|
| 461 |
+
style_ttl: torch.Tensor,
|
| 462 |
+
latent_mask: torch.Tensor,
|
| 463 |
+
text_mask: torch.Tensor,
|
| 464 |
+
current_step: torch.Tensor,
|
| 465 |
+
total_step: Optional[torch.Tensor] = None,
|
| 466 |
+
) -> torch.Tensor:
|
| 467 |
+
B = noisy_latent.shape[0]
|
| 468 |
+
|
| 469 |
+
if total_step is not None:
|
| 470 |
+
t_norm = current_step.reshape(B, 1, 1) / total_step.reshape(B, 1, 1)
|
| 471 |
+
reciprocal = 1.0 / total_step.reshape(B, 1, 1)
|
| 472 |
+
t_norm_flat = t_norm.reshape(B)
|
| 473 |
+
else:
|
| 474 |
+
t_norm_flat = current_step.reshape(B)
|
| 475 |
+
|
| 476 |
+
t_emb = self.time_encoder(t_norm_flat)
|
| 477 |
+
text_blc = text_emb.transpose(1, 2)
|
| 478 |
+
|
| 479 |
+
x = self.proj_in(noisy_latent)
|
| 480 |
+
x = x * latent_mask
|
| 481 |
+
|
| 482 |
+
for i, block in enumerate(self.main_blocks):
|
| 483 |
+
idx_in_super = i % 6
|
| 484 |
+
if idx_in_super == 0:
|
| 485 |
+
x = block(x, mask=latent_mask)
|
| 486 |
+
elif idx_in_super == 1:
|
| 487 |
+
x = block(x, t_emb)
|
| 488 |
+
x = x * latent_mask
|
| 489 |
+
elif idx_in_super == 2:
|
| 490 |
+
x = block(x, mask=latent_mask)
|
| 491 |
+
elif idx_in_super == 3:
|
| 492 |
+
x = block(x, context=text_blc, context_keys=None,
|
| 493 |
+
x_mask=latent_mask, context_mask=text_mask)
|
| 494 |
+
elif idx_in_super == 4:
|
| 495 |
+
x = block(x, mask=latent_mask)
|
| 496 |
+
elif idx_in_super == 5:
|
| 497 |
+
x = block(x, context=style_ttl,
|
| 498 |
+
context_keys=self.tile.expand(B, -1, -1),
|
| 499 |
+
x_mask=latent_mask, context_mask=None)
|
| 500 |
+
|
| 501 |
+
x = self.last_convnext(x, mask=latent_mask)
|
| 502 |
+
diff_out = self.proj_out(x) * latent_mask
|
| 503 |
+
|
| 504 |
+
if total_step is not None:
|
| 505 |
+
denoised = noisy_latent + reciprocal * diff_out
|
| 506 |
+
return denoised * latent_mask
|
| 507 |
+
return diff_out
|