notmax123 commited on
Commit
863d06f
·
1 Parent(s): 4818895

Clone tab: search fonts/pt_models with filename aliases; vendor models/

Browse files
.gitignore CHANGED
@@ -13,6 +13,8 @@ voices
13
  renikud.onnx
14
  model.onnx
15
  pt_weights
 
 
16
 
17
  # Virtual environments
18
  venv/
 
13
  renikud.onnx
14
  model.onnx
15
  pt_weights
16
+ fonts/pt_models
17
+ pt_models
18
 
19
  # Virtual environments
20
  venv/
app.py CHANGED
@@ -122,7 +122,7 @@ class TextProcessor:
122
  from phonemizer.separator import Separator
123
  EspeakWrapper.set_library(espeakng_loader.get_library_path())
124
  if hasattr(EspeakWrapper, "set_data_path"):
125
- EspeakWrapper.set_data_path(espeakng_loader.get_data_path())
126
  self._espeak_separator = Separator(phone="", word=" ", syllable="")
127
  self._espeak_ready = True
128
  except Exception as e:
@@ -157,7 +157,7 @@ class TextProcessor:
157
  return re.sub(r"\s+", " ", r.stdout.replace("\n", " ")).strip()
158
  except Exception as e:
159
  print(f"[WARN] espeak-ng subprocess failed for {lang}: {e}")
160
- return text
161
 
162
  def _phonemize_segment(self, content: str, lang: str) -> str:
163
  content = content.strip()
@@ -211,7 +211,7 @@ class UnicodeProcessor:
211
  if isinstance(raw, dict) and "char_to_id" in raw:
212
  self.pad_id = int(raw.get("pad_id", 0))
213
  self._char_to_id = {k: int(v) for k, v in raw["char_to_id"].items()}
214
- else:
215
  self._char_to_id = {chr(int(k)): int(v) for k, v in raw.items()}
216
  print(f"[INFO] Loaded vocab from {indexer_path} ({len(self._char_to_id)} entries)")
217
  else:
@@ -327,7 +327,7 @@ def chunk_text(text: str, max_len: int = 300) -> List[str]:
327
  for sentence in re.split(pattern, paragraph):
328
  if len(current) + len(sentence) + 1 <= max_len:
329
  current += (" " if current else "") + sentence
330
- else:
331
  if current:
332
  chunks.append(current.strip())
333
  current = sentence
@@ -550,7 +550,30 @@ def synthesize_text(text: str, voice: str, lang: str, steps: int, speed: float,
550
  # Voice-clone tab (runs export_new_voice.py)
551
  # ============================================================
552
  EXPORT_SCRIPT = os.path.join(os.path.dirname(__file__), "export_new_voice.py")
553
- PT_WEIGHTS_DIR = "pt_weights"
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
554
 
555
 
556
  def _refresh_voices() -> None:
@@ -568,19 +591,20 @@ def clone_voice(ref_wav: Optional[str], voice_name: str):
568
  safe = re.sub(r"[^\w\-]+", "_", voice_name.strip())
569
  out_path = os.path.join(VOICES_DIR, f"{safe}.json")
570
 
571
- needed = {
572
- "ae_ckpt": os.path.join(PT_WEIGHTS_DIR, "blue_codec.safetensors"),
573
- "ttl_ckpt": os.path.join(PT_WEIGHTS_DIR, "vf_estimator.safetensors"),
574
- "dp_ckpt": os.path.join(PT_WEIGHTS_DIR, "duration_predictor.safetensors"),
575
- "stats": os.path.join(PT_WEIGHTS_DIR, "stats_multilingual.pt"),
576
- }
577
- missing = [v for v in needed.values() if not os.path.exists(v)]
578
  if missing:
 
 
 
 
579
  return (
580
- "Voice cloning needs PyTorch checkpoints. Please fetch them first:\n"
581
- " hf download notmax123/blue blue_codec.safetensors duration_predictor.safetensors "
582
- "vf_estimator.safetensors stats_multilingual.pt --local-dir pt_weights\n\n"
583
- f"Missing: {', '.join(missing)}"
 
 
584
  ), gr.update()
585
 
586
  cmd = [
@@ -666,40 +690,40 @@ with gr.Blocks(title="BlueTTS — Multilingual TTS") as demo:
666
 
667
  with gr.Tabs():
668
  with gr.TabItem("Synthesize"):
669
- with gr.Column(elem_classes="card"):
670
- text_input = gr.Textbox(
671
  label="Text", placeholder="Type or paste text here…",
672
  lines=4, elem_classes="big-input",
673
- value="Great ideas become real when a small team keeps building every single day.",
674
- )
675
- with gr.Column(elem_classes="controls-row"):
676
- with gr.Row(elem_classes="ctrl-row1"):
677
- lang_input = gr.Dropdown(
678
  choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"),
679
  ("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"),
680
  ("Italian 🇮🇹", "it")],
681
  value="en", label="Language", elem_classes="ctrl-lang",
682
  )
683
- voice_input = gr.Dropdown(
684
  choices=list(VOICES.keys()),
685
  value=next(iter(VOICES.keys()), None),
686
  label="Voice", elem_classes="ctrl-voice",
687
- )
688
- with gr.Row(elem_classes="ctrl-row2"):
689
  steps_input = gr.Slider(2, 32, 8, step=1, label="Quality (steps)", elem_classes="ctrl-steps")
690
  speed_input = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Speed", elem_classes="ctrl-speed")
691
  cfg_input = gr.Slider(1.0, 7.0, 3.0, step=0.1, label="CFG Scale", elem_classes="ctrl-cfg")
692
- btn = gr.Button("⚡ Generate Speech", elem_classes="gen-btn")
693
- audio_out = gr.Audio(label="Output", type="numpy", autoplay=True)
694
- stats_out = gr.HTML()
695
 
696
  gr.Examples(examples=EXAMPLES, inputs=[text_input, lang_input], label="Examples")
697
 
698
- btn.click(
699
- synthesize_text,
700
  inputs=[text_input, voice_input, lang_input, steps_input, speed_input, cfg_input],
701
- outputs=[audio_out, stats_out],
702
- )
703
 
704
  with gr.TabItem("Clone Voice"):
705
  with gr.Column(elem_classes="card"):
 
122
  from phonemizer.separator import Separator
123
  EspeakWrapper.set_library(espeakng_loader.get_library_path())
124
  if hasattr(EspeakWrapper, "set_data_path"):
125
+ EspeakWrapper.set_data_path(espeakng_loader.get_data_path())
126
  self._espeak_separator = Separator(phone="", word=" ", syllable="")
127
  self._espeak_ready = True
128
  except Exception as e:
 
157
  return re.sub(r"\s+", " ", r.stdout.replace("\n", " ")).strip()
158
  except Exception as e:
159
  print(f"[WARN] espeak-ng subprocess failed for {lang}: {e}")
160
+ return text
161
 
162
  def _phonemize_segment(self, content: str, lang: str) -> str:
163
  content = content.strip()
 
211
  if isinstance(raw, dict) and "char_to_id" in raw:
212
  self.pad_id = int(raw.get("pad_id", 0))
213
  self._char_to_id = {k: int(v) for k, v in raw["char_to_id"].items()}
214
+ else:
215
  self._char_to_id = {chr(int(k)): int(v) for k, v in raw.items()}
216
  print(f"[INFO] Loaded vocab from {indexer_path} ({len(self._char_to_id)} entries)")
217
  else:
 
327
  for sentence in re.split(pattern, paragraph):
328
  if len(current) + len(sentence) + 1 <= max_len:
329
  current += (" " if current else "") + sentence
330
+ else:
331
  if current:
332
  chunks.append(current.strip())
333
  current = sentence
 
550
  # Voice-clone tab (runs export_new_voice.py)
551
  # ============================================================
552
  EXPORT_SCRIPT = os.path.join(os.path.dirname(__file__), "export_new_voice.py")
553
+
554
+ # Accept checkpoints from a handful of common locations (with the filename
555
+ # variants we've seen in the wild) so the clone tab works out of the box.
556
+ PT_WEIGHTS_SEARCH = [
557
+ "pt_weights",
558
+ os.path.join("fonts", "pt_models"),
559
+ "pt_models",
560
+ ]
561
+ PT_WEIGHT_ALIASES: dict[str, list[str]] = {
562
+ "ae_ckpt": ["blue_codec.safetensors", "blue_codec.pt"],
563
+ "ttl_ckpt": ["vf_estimator.safetensors", "vf_estimator.pt", "vf_estimetor.pt"],
564
+ "dp_ckpt": ["duration_predictor.safetensors", "duration_predictor.pt",
565
+ "duration_predictor_final.pt"],
566
+ "stats": ["stats_multilingual.pt", "stats.pt"],
567
+ }
568
+
569
+
570
+ def _find_pt_weight(aliases: list[str]) -> Optional[str]:
571
+ for d in PT_WEIGHTS_SEARCH:
572
+ for name in aliases:
573
+ p = os.path.join(d, name)
574
+ if os.path.exists(p):
575
+ return p
576
+ return None
577
 
578
 
579
  def _refresh_voices() -> None:
 
591
  safe = re.sub(r"[^\w\-]+", "_", voice_name.strip())
592
  out_path = os.path.join(VOICES_DIR, f"{safe}.json")
593
 
594
+ needed: dict[str, Optional[str]] = {k: _find_pt_weight(v) for k, v in PT_WEIGHT_ALIASES.items()}
595
+ missing = [k for k, v in needed.items() if v is None]
 
 
 
 
 
596
  if missing:
597
+ searched = ", ".join(PT_WEIGHTS_SEARCH)
598
+ wanted = "\n".join(
599
+ f" {k}: any of {PT_WEIGHT_ALIASES[k]}" for k in missing
600
+ )
601
  return (
602
+ "Voice cloning needs PyTorch checkpoints. I looked in: "
603
+ f"{searched}\nMissing:\n{wanted}\n\n"
604
+ "Fetch them with:\n"
605
+ " hf download notmax123/blue blue_codec.safetensors "
606
+ "duration_predictor.safetensors vf_estimator.safetensors "
607
+ "stats_multilingual.pt --local-dir pt_weights"
608
  ), gr.update()
609
 
610
  cmd = [
 
690
 
691
  with gr.Tabs():
692
  with gr.TabItem("Synthesize"):
693
+ with gr.Column(elem_classes="card"):
694
+ text_input = gr.Textbox(
695
  label="Text", placeholder="Type or paste text here…",
696
  lines=4, elem_classes="big-input",
697
+ value="Great ideas become real when a small team keeps building every single day.",
698
+ )
699
+ with gr.Column(elem_classes="controls-row"):
700
+ with gr.Row(elem_classes="ctrl-row1"):
701
+ lang_input = gr.Dropdown(
702
  choices=[("English 🇺🇸", "en"), ("Hebrew 🇮🇱", "he"),
703
  ("Spanish 🇪🇸", "es"), ("German 🇩🇪", "de"),
704
  ("Italian 🇮🇹", "it")],
705
  value="en", label="Language", elem_classes="ctrl-lang",
706
  )
707
+ voice_input = gr.Dropdown(
708
  choices=list(VOICES.keys()),
709
  value=next(iter(VOICES.keys()), None),
710
  label="Voice", elem_classes="ctrl-voice",
711
+ )
712
+ with gr.Row(elem_classes="ctrl-row2"):
713
  steps_input = gr.Slider(2, 32, 8, step=1, label="Quality (steps)", elem_classes="ctrl-steps")
714
  speed_input = gr.Slider(0.5, 2.0, 1.0, step=0.05, label="Speed", elem_classes="ctrl-speed")
715
  cfg_input = gr.Slider(1.0, 7.0, 3.0, step=0.1, label="CFG Scale", elem_classes="ctrl-cfg")
716
+ btn = gr.Button("⚡ Generate Speech", elem_classes="gen-btn")
717
+ audio_out = gr.Audio(label="Output", type="numpy", autoplay=True)
718
+ stats_out = gr.HTML()
719
 
720
  gr.Examples(examples=EXAMPLES, inputs=[text_input, lang_input], label="Examples")
721
 
722
+ btn.click(
723
+ synthesize_text,
724
  inputs=[text_input, voice_input, lang_input, steps_input, speed_input, cfg_input],
725
+ outputs=[audio_out, stats_out],
726
+ )
727
 
728
  with gr.TabItem("Clone Voice"):
729
  with gr.Column(elem_classes="card"):
models/__init__.py ADDED
File without changes
models/reference_encoder.py ADDED
@@ -0,0 +1,86 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+
4
+ from .text_encoder import ConvNeXtWrapper
5
+
6
+
7
+ class ReferenceEncoder(nn.Module):
8
+ def __init__(
9
+ self,
10
+ in_channels: int = 144,
11
+ d_model: int = 256,
12
+ hidden_dim: int = 1024,
13
+ num_blocks: int = 6,
14
+ num_tokens: int = 50,
15
+ num_heads: int = 2,
16
+ kernel_size: int = 5,
17
+ dilation_lst: list = None,
18
+ prototype_dim: int = 256,
19
+ n_units: int = 256,
20
+ style_value_dim: int = 256,
21
+ ):
22
+ super().__init__()
23
+ self.d_model = d_model
24
+ self.num_tokens = num_tokens
25
+
26
+ if hidden_dim % d_model != 0:
27
+ raise ValueError(f"hidden_dim ({hidden_dim}) must be divisible by d_model ({d_model})")
28
+ mlp_ratio = hidden_dim // d_model
29
+
30
+ self.input_proj = nn.Conv1d(in_channels, d_model, kernel_size=1)
31
+ self.convnext = ConvNeXtWrapper(
32
+ d_model,
33
+ n_layers=num_blocks,
34
+ expansion_factor=mlp_ratio,
35
+ kernel_size=kernel_size,
36
+ dilation_lst=dilation_lst,
37
+ )
38
+
39
+ self.ref_keys = nn.Parameter(torch.randn(num_tokens, prototype_dim) * 0.02)
40
+ self.q_proj = nn.Linear(prototype_dim, n_units) if prototype_dim != n_units else nn.Identity()
41
+ self.out_proj = nn.Linear(n_units, style_value_dim) if n_units != style_value_dim else nn.Identity()
42
+
43
+ self.attn1 = nn.MultiheadAttention(
44
+ embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
45
+ )
46
+ self.attn2 = nn.MultiheadAttention(
47
+ embed_dim=n_units, num_heads=num_heads, kdim=d_model, vdim=d_model, batch_first=True
48
+ )
49
+
50
+ def forward(self, z_ref: torch.Tensor, mask: torch.Tensor = None):
51
+ B = z_ref.shape[0]
52
+ x = self.input_proj(z_ref)
53
+ x = self.convnext(x, mask=mask)
54
+ kv = x.transpose(1, 2)
55
+
56
+ key_padding_mask = None
57
+ if mask is not None:
58
+ key_padding_mask = (mask.squeeze(1) == 0)
59
+
60
+ q0 = self.ref_keys.unsqueeze(0).expand(B, -1, -1)
61
+ q0 = self.q_proj(q0)
62
+
63
+ q1, _ = self.attn1(query=q0, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
64
+ q2 = q0 + q1
65
+ out, _ = self.attn2(query=q2, key=kv, value=kv, key_padding_mask=key_padding_mask, need_weights=False)
66
+ return self.out_proj(out)
67
+
68
+ @staticmethod
69
+ def remap_legacy_state_dict(state_dict: dict) -> dict:
70
+ """Remap pre-refactor checkpoints (per-layer pre-norm + FFN) onto current layout."""
71
+ remapped = {}
72
+ legacy_prefix_map = {
73
+ "attn_layers.0.attn.": "attn1.",
74
+ "attn_layers.1.attn.": "attn2.",
75
+ }
76
+ drop_substrings = (".norm_q.", ".norm_kv.", ".ffn.", "pos_emb.")
77
+ for k, v in state_dict.items():
78
+ if any(s in k for s in drop_substrings):
79
+ continue
80
+ new_key = k
81
+ for old, new in legacy_prefix_map.items():
82
+ if new_key.startswith(old):
83
+ new_key = new + new_key[len(old):]
84
+ break
85
+ remapped[new_key] = v
86
+ return remapped
models/text_encoder.py ADDED
@@ -0,0 +1,358 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+
5
+
6
+ class LayerNorm(nn.Module):
7
+ def __init__(self, channels: int, eps: float = 1e-6):
8
+ super().__init__()
9
+ self.norm = nn.LayerNorm(channels, eps=eps)
10
+
11
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
12
+ x = x.transpose(1, 2)
13
+ x = self.norm(x)
14
+ x = x.transpose(1, 2)
15
+ return x
16
+
17
+
18
+ class ConvNeXtBlock(nn.Module):
19
+ def __init__(self, dim: int, expansion_factor: int = 4, kernel_size: int = 5, dilation: int = 1, layer_scale_init_value: float = 1e-6):
20
+ super().__init__()
21
+ hidden_dim = dim * expansion_factor
22
+ if (kernel_size % 2) != 1:
23
+ raise ValueError(f"ConvNeXtBlock expects odd kernel_size, got {kernel_size}")
24
+ self.pad = ((kernel_size - 1) // 2) * dilation
25
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=0, groups=dim, dilation=dilation)
26
+ self.norm = LayerNorm(dim, eps=1e-6)
27
+ self.pwconv1 = nn.Conv1d(dim, hidden_dim, kernel_size=1)
28
+ self.act = nn.GELU()
29
+ self.pwconv2 = nn.Conv1d(hidden_dim, dim, kernel_size=1)
30
+ self.gamma = nn.Parameter(layer_scale_init_value * torch.ones((1, dim, 1)), requires_grad=True)
31
+
32
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
33
+ if mask is not None:
34
+ x = x * mask
35
+ residual = x
36
+
37
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
38
+ x = self.dwconv(x)
39
+ if mask is not None:
40
+ x = x * mask
41
+
42
+ x = self.norm(x)
43
+ x = self.pwconv1(x)
44
+ x = self.act(x)
45
+ x = self.pwconv2(x)
46
+ x = self.gamma * x
47
+
48
+ x = residual + x
49
+ if mask is not None:
50
+ x = x * mask
51
+ return x
52
+
53
+
54
+ class ConvNeXtWrapper(nn.Module):
55
+ def __init__(self, d_model, n_layers, expansion_factor, kernel_size=5, dilation_lst=None):
56
+ super().__init__()
57
+ if dilation_lst is None:
58
+ dilation_lst = [1] * n_layers
59
+ self.convnext = nn.ModuleList([
60
+ ConvNeXtBlock(d_model, expansion_factor=expansion_factor, kernel_size=kernel_size, dilation=dilation_lst[i])
61
+ for i in range(n_layers)
62
+ ])
63
+
64
+ def forward(self, x, mask=None):
65
+ for block in self.convnext:
66
+ x = block(x, mask=mask)
67
+ return x
68
+
69
+
70
+ class RelativeMultiHeadAttention(nn.Module):
71
+ def __init__(self, channels: int, n_heads: int, window_size: int = 4, p_dropout: float = 0.0):
72
+ super().__init__()
73
+ assert channels % n_heads == 0
74
+ self.channels = channels
75
+ self.n_heads = n_heads
76
+ self.head_dim = channels // n_heads
77
+ self.scale = self.head_dim ** -0.5
78
+ self.window_size = window_size
79
+
80
+ self.conv_q = nn.Conv1d(channels, channels, 1)
81
+ self.conv_k = nn.Conv1d(channels, channels, 1)
82
+ self.conv_v = nn.Conv1d(channels, channels, 1)
83
+ self.conv_o = nn.Conv1d(channels, channels, 1)
84
+
85
+ self.emb_rel_k = nn.Parameter(torch.randn(1, 2 * window_size + 1, self.head_dim) * 0.02)
86
+ self.emb_rel_v = nn.Parameter(torch.randn(1, 2 * window_size + 1, self.head_dim) * 0.02)
87
+
88
+ self.drop = nn.Dropout(p_dropout)
89
+
90
+ def forward(self, x: torch.Tensor, attn_mask: torch.Tensor | None = None) -> torch.Tensor:
91
+ B, C, L = x.shape
92
+
93
+ q = self.conv_q(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
94
+ q = q * self.scale
95
+ k = self.conv_k(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
96
+ v = self.conv_v(x).view(B, self.n_heads, self.head_dim, L).transpose(2, 3)
97
+
98
+ scores = torch.matmul(q, k.transpose(-2, -1))
99
+
100
+ t = torch.arange(L, device=x.device)
101
+ diff = t[None, :] - t[:, None]
102
+ window_mask = (diff.abs() <= self.window_size)
103
+ diff_clamped = torch.clamp(diff, -self.window_size, self.window_size)
104
+ indices = diff_clamped + self.window_size
105
+
106
+ rel_k = self.emb_rel_k[0][indices]
107
+ rel_scores = torch.einsum("bhld,ljd->bhlj", q, rel_k)
108
+ rel_scores = rel_scores * window_mask[None, None, :, :]
109
+
110
+ scores = scores + rel_scores
111
+
112
+ if attn_mask is not None:
113
+ scores = scores.masked_fill(attn_mask == 0, -1e4)
114
+
115
+ attn = torch.softmax(scores, dim=-1)
116
+ attn = self.drop(attn)
117
+
118
+ out = torch.matmul(attn, v)
119
+
120
+ rel_v = self.emb_rel_v[0][indices]
121
+ rel_v = rel_v * window_mask[:, :, None]
122
+ out_rel = torch.einsum("bhlj,ljd->bhld", attn, rel_v)
123
+
124
+ out = out + out_rel
125
+ out = out.transpose(2, 3).contiguous().view(B, C, L)
126
+ out = self.conv_o(out)
127
+ return out
128
+
129
+
130
+ class FeedForward(nn.Module):
131
+ def __init__(self, channels: int, filter_channels: int, kernel_size: int = 1, p_dropout: float = 0.0):
132
+ super().__init__()
133
+ self.conv_1 = nn.Conv1d(channels, filter_channels, kernel_size)
134
+ self.relu = nn.ReLU()
135
+ self.drop = nn.Dropout(p_dropout)
136
+ self.conv_2 = nn.Conv1d(filter_channels, channels, kernel_size)
137
+
138
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
139
+ if mask is not None:
140
+ x = x * mask
141
+ x = self.conv_1(x)
142
+ x = self.relu(x)
143
+ x = self.drop(x)
144
+ if mask is not None:
145
+ x = x * mask
146
+ x = self.conv_2(x)
147
+ if mask is not None:
148
+ x = x * mask
149
+ return x
150
+
151
+
152
+ class AttnEncoder(nn.Module):
153
+ def __init__(self, channels: int, n_heads: int, filter_channels: int, n_layers: int, p_dropout: float = 0.0):
154
+ super().__init__()
155
+ self.attn_layers = nn.ModuleList(
156
+ [RelativeMultiHeadAttention(channels, n_heads, window_size=4, p_dropout=p_dropout) for _ in range(n_layers)]
157
+ )
158
+ self.norm_layers_1 = nn.ModuleList([LayerNorm(channels) for _ in range(n_layers)])
159
+ self.ffn_layers = nn.ModuleList(
160
+ [FeedForward(channels, filter_channels, p_dropout=p_dropout) for _ in range(n_layers)]
161
+ )
162
+ self.norm_layers_2 = nn.ModuleList([LayerNorm(channels) for _ in range(n_layers)])
163
+
164
+ def forward(self, x: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
165
+ if mask is not None:
166
+ x = x * mask
167
+
168
+ attn_mask = None
169
+ if mask is not None:
170
+ attn_mask = mask.unsqueeze(-1) * mask.unsqueeze(-2)
171
+
172
+ for i in range(len(self.attn_layers)):
173
+ residual = x
174
+ x = self.attn_layers[i](x, attn_mask=attn_mask)
175
+ x = residual + x
176
+ x = self.norm_layers_1[i](x)
177
+
178
+ residual_ffn = x
179
+ x_ffn = self.ffn_layers[i](x, mask=mask)
180
+ x = residual_ffn + x_ffn
181
+ x = self.norm_layers_2[i](x)
182
+
183
+ if mask is not None:
184
+ x = x * mask
185
+ return x
186
+
187
+
188
+ class LinearWrapped(nn.Module):
189
+ def __init__(self, in_dim, out_dim=None):
190
+ super().__init__()
191
+ if out_dim is None:
192
+ out_dim = in_dim
193
+ self.linear = nn.Linear(in_dim, out_dim)
194
+
195
+ def forward(self, x):
196
+ return self.linear(x)
197
+
198
+
199
+ class StyleNorm(nn.Module):
200
+ def __init__(self, dim, eps: float = 1e-6):
201
+ super().__init__()
202
+ self.norm = nn.LayerNorm(dim, eps=eps)
203
+
204
+ def forward(self, x):
205
+ x = self.norm(x)
206
+ x = x.transpose(1, 2)
207
+ return x
208
+
209
+
210
+ class TextEmbedderWrapper(nn.Module):
211
+ def __init__(self, vocab_size, d_model):
212
+ super().__init__()
213
+ self.char_embedder = nn.Embedding(vocab_size, d_model)
214
+
215
+ def forward(self, x):
216
+ return self.char_embedder(x)
217
+
218
+
219
+ class StyleAttentionLayer(nn.Module):
220
+ def __init__(self, text_dim: int, style_dim: int, n_units: int, num_heads: int = 2, num_style_tokens: int = 50):
221
+ super().__init__()
222
+ assert n_units % num_heads == 0
223
+ self.num_heads = num_heads
224
+ self.dim = n_units
225
+ self.head_dim = n_units // num_heads
226
+ self.scale = n_units ** -0.5
227
+
228
+ self.W_query = LinearWrapped(text_dim, n_units)
229
+ self.W_value = LinearWrapped(style_dim, n_units)
230
+ self.out_fc = LinearWrapped(n_units, text_dim)
231
+
232
+ # ONNX folds `tanh(W_key(style_key))` into a baked constant; mirror with a learnable parameter.
233
+ self.key_const = nn.Parameter(torch.randn(num_heads, 1, self.head_dim, num_style_tokens) * 0.02)
234
+
235
+ def forward(self, x: torch.Tensor, values: torch.Tensor, mask_t: torch.Tensor | None = None) -> torch.Tensor:
236
+ B, T, C = x.shape
237
+
238
+ q = self.W_query(x)
239
+ qs = q.chunk(self.num_heads, dim=-1)
240
+ q = torch.stack(qs, dim=0)
241
+
242
+ k = self.key_const
243
+
244
+ if values.dim() == 2:
245
+ values = values.unsqueeze(0)
246
+ if values.shape[0] != B:
247
+ values = values.expand(B, -1, -1)
248
+
249
+ v = self.W_value(values)
250
+ vs = v.chunk(self.num_heads, dim=-1)
251
+ v = torch.stack(vs, dim=0)
252
+
253
+ scores = torch.matmul(q, k) * self.scale
254
+ attn = torch.softmax(scores, dim=-1)
255
+
256
+ if mask_t is not None:
257
+ attn_mask = (mask_t.unsqueeze(0) == 0)
258
+ attn = attn.masked_fill(attn_mask, 0.0)
259
+
260
+ out = torch.matmul(attn, v)
261
+
262
+ outs = out.chunk(self.num_heads, dim=0)
263
+ out = torch.cat(outs, dim=-1).squeeze(0)
264
+
265
+ out = self.out_fc(out)
266
+
267
+ if mask_t is not None:
268
+ out = out * mask_t
269
+ return out
270
+
271
+
272
+ class StyleAttention(nn.Module):
273
+ def __init__(self, text_dim: int, style_dim: int, n_units: int, num_heads: int = 2, num_style_tokens: int = 50):
274
+ super().__init__()
275
+ # attention1 / attention2 are separate: each owns its baked key constant.
276
+ self.attention1 = StyleAttentionLayer(text_dim, style_dim, n_units, num_heads, num_style_tokens)
277
+ self.attention2 = StyleAttentionLayer(text_dim, style_dim, n_units, num_heads, num_style_tokens)
278
+ self.norm = StyleNorm(text_dim)
279
+
280
+ def forward(self, x: torch.Tensor, style_values: torch.Tensor, mask: torch.Tensor | None = None) -> torch.Tensor:
281
+ x = x.transpose(1, 2)
282
+
283
+ mask_t = None
284
+ if mask is not None:
285
+ mask_t = mask.transpose(1, 2)
286
+
287
+ out1 = self.attention1(x, style_values, mask_t=mask_t)
288
+ x1 = x + out1
289
+
290
+ out2 = self.attention2(x1, style_values, mask_t=mask_t)
291
+ x2 = x + out2
292
+
293
+ x = self.norm(x2)
294
+ if mask is not None:
295
+ x = x * mask
296
+ return x
297
+
298
+
299
+ class TextEncoder(nn.Module):
300
+ def __init__(
301
+ self,
302
+ vocab_size: int = 256,
303
+ d_model: int = 256,
304
+ n_conv_layers: int = 6,
305
+ n_attn_layers: int = 4,
306
+ expansion_factor: int = 4,
307
+ p_dropout: float = 0.1,
308
+ kernel_size: int = 5,
309
+ dilation_lst: list = None,
310
+ attn_n_heads: int = 4,
311
+ attn_filter_channels: int = 1024,
312
+ spte_n_heads: int = 2,
313
+ spte_text_dim: int = 256,
314
+ spte_style_dim: int = 256,
315
+ spte_n_units: int = 256,
316
+ spte_n_style: int = 50,
317
+ ):
318
+ super().__init__()
319
+ self.d_model = d_model
320
+ self.text_embedder = TextEmbedderWrapper(vocab_size, d_model)
321
+ self.convnext = ConvNeXtWrapper(
322
+ d_model, n_conv_layers, expansion_factor, kernel_size=kernel_size, dilation_lst=dilation_lst
323
+ )
324
+ self.attn_encoder = AttnEncoder(
325
+ d_model,
326
+ n_heads=attn_n_heads,
327
+ filter_channels=attn_filter_channels,
328
+ n_layers=n_attn_layers,
329
+ p_dropout=p_dropout,
330
+ )
331
+ self.speech_prompted_text_encoder = StyleAttention(
332
+ text_dim=spte_text_dim,
333
+ style_dim=spte_style_dim,
334
+ n_units=spte_n_units,
335
+ num_heads=spte_n_heads,
336
+ num_style_tokens=spte_n_style,
337
+ )
338
+ self.proj_out = nn.Identity()
339
+
340
+ def forward(self, text_ids: torch.Tensor, style_ttl: torch.Tensor, text_mask: torch.Tensor | None = None) -> torch.Tensor:
341
+ x = self.text_embedder(text_ids)
342
+ x = x.transpose(1, 2)
343
+
344
+ if text_mask is not None:
345
+ x = x * text_mask
346
+
347
+ x = self.convnext(x, mask=text_mask)
348
+ convnext_output = x
349
+
350
+ x = self.attn_encoder(x, mask=text_mask)
351
+ x = x + convnext_output
352
+
353
+ x = self.proj_out(x)
354
+ if text_mask is not None:
355
+ x = x * text_mask
356
+
357
+ x = self.speech_prompted_text_encoder(x, style_values=style_ttl, mask=text_mask)
358
+ return x
models/utils.py ADDED
@@ -0,0 +1,197 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import torch
2
+ import torch.nn as nn
3
+ import torch.nn.functional as F
4
+ import torchaudio.transforms as T
5
+
6
+
7
+ def compress_latents(z: torch.Tensor, factor: int = 6) -> torch.Tensor:
8
+ B, C, T = z.shape
9
+ if T % factor != 0:
10
+ pad = factor - (T % factor)
11
+ z = torch.nn.functional.pad(z, (0, pad))
12
+ T = T + pad
13
+ return z.view(B, C, T // factor, factor).permute(0, 1, 3, 2).flatten(1, 2)
14
+
15
+
16
+ def decompress_latents(z: torch.Tensor, factor: int = 6, target_channels: int = 24) -> torch.Tensor:
17
+ B, _, T_low = z.shape
18
+ return z.view(B, target_channels, factor, T_low).permute(0, 1, 3, 2).flatten(2, 3)
19
+
20
+
21
+ def _resolve_vocab_size(char_dict_path, default=256):
22
+ import json as _json
23
+ import os as _os
24
+ if char_dict_path and _os.path.exists(char_dict_path):
25
+ try:
26
+ with open(char_dict_path, "r") as f:
27
+ cd = _json.load(f)
28
+ if isinstance(cd, dict) and "vocab_size" in cd:
29
+ return int(cd["vocab_size"])
30
+ if isinstance(cd, dict) and "char_to_id" in cd and isinstance(cd["char_to_id"], dict):
31
+ return max(cd["char_to_id"].values()) + 1
32
+ if isinstance(cd, dict):
33
+ return max(cd.values()) + 1 if cd else default
34
+ return len(cd)
35
+ except Exception:
36
+ pass
37
+ return default
38
+
39
+
40
+ def load_ttl_config(config_path="configs/tts.json"):
41
+ import json
42
+ with open(config_path, "r") as f:
43
+ full_config = json.load(f)
44
+
45
+ ttl = full_config["ttl"]
46
+ ae = full_config.get("ae", {})
47
+ dp = full_config.get("dp", {})
48
+
49
+ te = ttl["text_encoder"]
50
+ se = ttl["style_encoder"]
51
+ vf = ttl["vector_field"]
52
+ um = ttl["uncond_masker"]
53
+
54
+ char_dict_path = te.get("char_dict_path", te.get("text_embedder", {}).get("char_dict_path"))
55
+ vocab_size = _resolve_vocab_size(char_dict_path, default=256)
56
+
57
+ dp_char_dict_path = (
58
+ dp.get("sentence_encoder", {}).get("char_dict_path")
59
+ or dp.get("sentence_encoder", {}).get("text_embedder", {}).get("char_dict_path")
60
+ )
61
+ dp_vocab_size = _resolve_vocab_size(dp_char_dict_path, default=vocab_size)
62
+
63
+ ae_dec = ae.get("decoder", {})
64
+ ae_dec_cfg = {
65
+ "idim": ae_dec.get("idim", 24),
66
+ "hdim": ae_dec.get("hdim", 512),
67
+ "intermediate_dim": ae_dec.get("intermediate_dim", 2048),
68
+ "ksz": ae_dec.get("ksz", 7),
69
+ "dilation_lst": ae_dec.get("dilation_lst", [1, 2, 4, 1, 2, 4, 1, 1, 1, 1]),
70
+ "chunk_compress_factor": ae.get("chunk_compress_factor", 1),
71
+ "head": {
72
+ "idim": ae_dec.get("head", {}).get("idim", ae_dec.get("hdim", 512)),
73
+ "hdim": ae_dec.get("head", {}).get("hdim", 2048),
74
+ "odim": ae_dec.get("head", {}).get("odim", 512),
75
+ "ksz": ae_dec.get("head", {}).get("ksz", 3),
76
+ },
77
+ }
78
+
79
+ ae_enc = ae.get("encoder", {})
80
+ ae_enc_spec = ae_enc.get("spec_processor", {})
81
+ ae_enc_cfg = {
82
+ "ksz": ae_enc.get("ksz", 7),
83
+ "hdim": ae_enc.get("hdim", 512),
84
+ "intermediate_dim": ae_enc.get("intermediate_dim", 2048),
85
+ "dilation_lst": ae_enc.get("dilation_lst", [1] * 10),
86
+ "odim": ae_enc.get("odim", 24),
87
+ "idim": ae_enc.get("idim", 1253),
88
+ }
89
+
90
+ dp_se = dp.get("style_encoder", {}).get("style_token_layer", {})
91
+
92
+ return {
93
+ "full_config": full_config,
94
+ "ttl": ttl,
95
+ "ae": ae,
96
+ "dp": dp,
97
+
98
+ "vocab_size": vocab_size,
99
+ "char_dict_path": char_dict_path,
100
+ "dp_vocab_size": dp_vocab_size,
101
+
102
+ "latent_dim": ttl["latent_dim"],
103
+ "chunk_compress_factor": ttl["chunk_compress_factor"],
104
+ "compressed_channels": ttl["latent_dim"] * ttl["chunk_compress_factor"],
105
+ "normalizer_scale": ttl["normalizer"]["scale"],
106
+ "sigma_min": ttl["flow_matching"]["sig_min"],
107
+ "Ke": ttl["batch_expander"]["n_batch_expand"],
108
+
109
+ "te_d_model": te["text_embedder"]["char_emb_dim"],
110
+ "te_convnext_layers": te["convnext"]["num_layers"],
111
+ "te_expansion_factor": te["convnext"]["intermediate_dim"] // te["text_embedder"]["char_emb_dim"],
112
+ "te_attn_n_layers": te["attn_encoder"]["n_layers"],
113
+ "te_attn_p_dropout": te["attn_encoder"]["p_dropout"],
114
+
115
+ "se_d_model": se["proj_in"]["odim"],
116
+ "se_hidden_dim": se["convnext"]["intermediate_dim"],
117
+ "se_num_blocks": se["convnext"]["num_layers"],
118
+ "se_n_style": se["style_token_layer"]["n_style"],
119
+ "se_n_heads": se["style_token_layer"]["n_heads"],
120
+
121
+ "prob_both_uncond": um["prob_both_uncond"],
122
+ "prob_text_uncond": um["prob_text_uncond"],
123
+ "uncond_init_std": um["std"],
124
+ "um_text_dim": um["text_dim"],
125
+ "um_n_style": um["n_style"],
126
+ "um_style_key_dim": um["style_key_dim"],
127
+ "um_style_value_dim": um["style_value_dim"],
128
+
129
+ "vf_hidden": vf["proj_in"]["odim"],
130
+ "vf_time_dim": vf["time_encoder"]["time_dim"],
131
+ "vf_n_blocks": vf["main_blocks"]["n_blocks"],
132
+ "vf_text_dim": vf["main_blocks"]["text_cond_layer"]["text_dim"],
133
+ "vf_text_n_heads": vf["main_blocks"]["text_cond_layer"]["n_heads"],
134
+ "vf_style_dim": vf["main_blocks"]["style_cond_layer"]["style_dim"],
135
+ "vf_rotary_scale": vf["main_blocks"]["text_cond_layer"]["rotary_scale"],
136
+
137
+ "ae_dec_cfg": ae_dec_cfg,
138
+ "ae_enc_cfg": ae_enc_cfg,
139
+ "ae_sample_rate": ae.get("sample_rate", 44100),
140
+ "ae_n_fft": ae_enc_spec.get("n_fft", 2048),
141
+ "ae_hop_length": ae_enc_spec.get("hop_length", 512),
142
+ "ae_n_mels": ae_enc_spec.get("n_mels", 1253),
143
+
144
+ "dp_style_tokens": dp_se.get("n_style", 8),
145
+ "dp_style_dim": dp_se.get("style_value_dim", 16),
146
+ }
147
+
148
+
149
+ class MelSpectrogram(nn.Module):
150
+ def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
151
+ hop_length=512, n_mels=1253, f_min=0, f_max=None):
152
+ super().__init__()
153
+ self.mel = T.MelSpectrogram(
154
+ sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
155
+ hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
156
+ center=True, power=1.0,
157
+ )
158
+
159
+ def forward(self, audio):
160
+ mel = torch.log(torch.clamp(self.mel(audio), min=1e-5))
161
+ return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel
162
+
163
+
164
+ class MelSpectrogramNoLog(nn.Module):
165
+ def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
166
+ hop_length=512, n_mels=1253, f_min=0, f_max=12000, power=1.0):
167
+ super().__init__()
168
+ self.mel = T.MelSpectrogram(
169
+ sample_rate=sample_rate, n_fft=n_fft, win_length=win_length,
170
+ hop_length=hop_length, n_mels=n_mels, f_min=f_min, f_max=f_max,
171
+ center=True, power=power,
172
+ )
173
+
174
+ def forward(self, audio):
175
+ mel = self.mel(audio)
176
+ return mel.squeeze(1) if mel.dim() == 4 and mel.shape[1] == 1 else mel
177
+
178
+
179
+ class LinearMelSpectrogram(nn.Module):
180
+ def __init__(self, sample_rate=44100, n_fft=2048, win_length=2048,
181
+ hop_length=512, n_mels=1253, f_min=0, f_max=None):
182
+ super().__init__()
183
+ self.spectrogram = T.Spectrogram(
184
+ n_fft=n_fft, win_length=win_length, hop_length=hop_length,
185
+ center=True, power=1.0,
186
+ )
187
+ self.mel_scale = T.MelScale(
188
+ n_mels=n_mels, sample_rate=sample_rate,
189
+ n_stft=n_fft // 2 + 1, f_min=f_min, f_max=f_max,
190
+ )
191
+
192
+ def forward(self, audio):
193
+ spec = self.spectrogram(audio)
194
+ mel = self.mel_scale(spec)
195
+ log_spec = torch.log(torch.clamp(spec, min=1e-5))
196
+ log_mel = torch.log(torch.clamp(mel, min=1e-5))
197
+ return torch.cat([log_spec, log_mel], dim=1)
models/vf_estimator.py ADDED
@@ -0,0 +1,507 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import math
2
+ from typing import Optional
3
+
4
+ import torch
5
+ import torch.nn as nn
6
+ import torch.nn.functional as F
7
+
8
+
9
+ class LinearWrapper(nn.Module):
10
+ def __init__(self, in_features: int, out_features: int):
11
+ super().__init__()
12
+ self.linear = nn.Linear(in_features, out_features)
13
+
14
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
15
+ return self.linear(x)
16
+
17
+
18
+ class LayerNormWrapper(nn.Module):
19
+ def __init__(self, dim: int, eps: float = 1e-6):
20
+ super().__init__()
21
+ self.norm = nn.LayerNorm(dim, eps=eps)
22
+
23
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
24
+ x = x.transpose(1, 2)
25
+ x = self.norm(x)
26
+ x = x.transpose(1, 2)
27
+ return x
28
+
29
+
30
+ class ProjectionWrapper(nn.Module):
31
+ def __init__(self, in_channels: int, out_channels: int):
32
+ super().__init__()
33
+ self.net = nn.Conv1d(in_channels, out_channels, kernel_size=1, bias=False)
34
+
35
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
36
+ return self.net(x)
37
+
38
+
39
+ class Mish(nn.Module):
40
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
41
+ return x * torch.tanh(F.softplus(x))
42
+
43
+
44
+ class SinusoidalPosEmb(nn.Module):
45
+ def __init__(self, dim: int, scale: float = 1000.0):
46
+ super().__init__()
47
+ self.dim = dim
48
+ self.scale = scale
49
+
50
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
51
+ x = x * self.scale
52
+ device = x.device
53
+ half_dim = self.dim // 2
54
+ emb = math.log(10000) / (half_dim - 1)
55
+ emb = torch.exp(torch.arange(half_dim, device=device) * -emb)
56
+ emb = x[:, None] * emb[None, :]
57
+ emb = torch.cat((emb.sin(), emb.cos()), dim=-1)
58
+ return emb
59
+
60
+
61
+ class TimeEncoder(nn.Module):
62
+ def __init__(self, embed_dim: int, hdim: int = 256):
63
+ super().__init__()
64
+ self.sinusoidal = SinusoidalPosEmb(embed_dim, scale=1000.0)
65
+ self.mlp = nn.Sequential(
66
+ LinearWrapper(embed_dim, hdim),
67
+ Mish(),
68
+ LinearWrapper(hdim, embed_dim),
69
+ )
70
+
71
+ def forward(self, x: torch.Tensor) -> torch.Tensor:
72
+ x = self.sinusoidal(x)
73
+ x = self.mlp(x)
74
+ return x
75
+
76
+
77
+ class TimeCondBlock(nn.Module):
78
+ def __init__(self, time_dim: int, channels: int):
79
+ super().__init__()
80
+ self.linear = LinearWrapper(time_dim, channels)
81
+ # Zero-init so the block starts as identity.
82
+ nn.init.zeros_(self.linear.linear.weight)
83
+ nn.init.zeros_(self.linear.linear.bias)
84
+
85
+ def forward(self, x: torch.Tensor, time_emb: torch.Tensor) -> torch.Tensor:
86
+ cond = self.linear(time_emb)
87
+ cond = cond.unsqueeze(-1)
88
+ return x + cond
89
+
90
+
91
+ class ConvNeXtBlock1D(nn.Module):
92
+ def __init__(self, dim: int, kernel_size: int = 5, expansion: int = 2, dropout: float = 0.0, dilation: int = 1):
93
+ super().__init__()
94
+ self.pad = ((kernel_size - 1) // 2) * dilation
95
+ self.dwconv = nn.Conv1d(dim, dim, kernel_size=kernel_size, padding=0, groups=dim, dilation=dilation)
96
+ self.norm = LayerNormWrapper(dim)
97
+ self.pwconv1 = nn.Conv1d(dim, dim * expansion, kernel_size=1)
98
+ self.act = nn.GELU()
99
+ self.pwconv2 = nn.Conv1d(dim * expansion, dim, kernel_size=1)
100
+ self.gamma = nn.Parameter(torch.ones(1, dim, 1) * 1e-6)
101
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
102
+
103
+ def forward(self, x: torch.Tensor, mask: Optional[torch.Tensor] = None) -> torch.Tensor:
104
+ if mask is not None:
105
+ x = x * mask
106
+ residual = x
107
+
108
+ x = F.pad(x, (self.pad, self.pad), mode="replicate")
109
+ x = self.dwconv(x)
110
+ if mask is not None:
111
+ x = x * mask
112
+
113
+ x = self.norm(x)
114
+ x = self.pwconv1(x)
115
+ x = self.act(x)
116
+ x = self.pwconv2(x)
117
+ x = self.gamma * x
118
+ x = self.dropout(x)
119
+
120
+ x = x + residual
121
+ if mask is not None:
122
+ x = x * mask
123
+ return x
124
+
125
+
126
+ class ConvNeXtStack(nn.Module):
127
+ def __init__(self, channels, kernel_size, dilations):
128
+ super().__init__()
129
+ self.convnext = nn.ModuleList([
130
+ ConvNeXtBlock1D(channels, kernel_size=kernel_size, dilation=d, expansion=2)
131
+ for d in dilations
132
+ ])
133
+
134
+ def forward(self, x, mask=None):
135
+ for blk in self.convnext:
136
+ x = blk(x, mask)
137
+ return x
138
+
139
+
140
+ def apply_rotary_pos_emb(x: torch.Tensor, cos: torch.Tensor, sin: torch.Tensor) -> torch.Tensor:
141
+ B, H, T, D = x.shape
142
+ assert D % 2 == 0, "head_dim must be even for RoPE"
143
+
144
+ x1 = x[..., : D // 2]
145
+ x2 = x[..., D // 2 :]
146
+
147
+ if cos.dim() == 2:
148
+ cos = cos[None, None, :, :]
149
+ sin = sin[None, None, :, :]
150
+ elif cos.dim() == 3:
151
+ cos = cos.unsqueeze(1)
152
+ sin = sin.unsqueeze(1)
153
+
154
+ x1_rot = x1 * cos - x2 * sin
155
+ x2_rot = x1 * sin + x2 * cos
156
+
157
+ return torch.cat([x1_rot, x2_rot], dim=-1)
158
+
159
+
160
+ class AttentionModule(nn.Module):
161
+ """Text path uses LARoPE; style path uses tanh on keys (no RoPE)."""
162
+
163
+ def __init__(
164
+ self,
165
+ d_model: int,
166
+ d_context: int,
167
+ num_heads: int,
168
+ attn_dim: int,
169
+ use_rope: bool,
170
+ dropout: float = 0.0,
171
+ rope_gamma: float = 10.0,
172
+ attn_scale: Optional[float] = None,
173
+ rotary_base: float = 10000.0,
174
+ use_residual: bool = True,
175
+ ):
176
+ super().__init__()
177
+ assert attn_dim % num_heads == 0
178
+
179
+ self.d_model = d_model
180
+ self.num_heads = num_heads
181
+ self.head_dim = attn_dim // num_heads
182
+ self.attn_dim = attn_dim
183
+ self.use_rope = use_rope
184
+ self.use_residual = use_residual
185
+ self.rope_gamma = rope_gamma
186
+ self.attn_scale = attn_scale if attn_scale is not None else math.sqrt(self.attn_dim)
187
+
188
+ self.W_query = LinearWrapper(d_model, attn_dim)
189
+ self.W_key = LinearWrapper(d_context, attn_dim)
190
+ self.W_value = LinearWrapper(d_context, attn_dim)
191
+ self.out_fc = LinearWrapper(attn_dim, d_model)
192
+
193
+ self.dropout = nn.Dropout(dropout) if dropout > 0.0 else nn.Identity()
194
+
195
+ if use_rope:
196
+ inv_freq = 1.0 / (rotary_base ** (torch.arange(0, self.head_dim, 2, dtype=torch.float32) / self.head_dim))
197
+ theta = (inv_freq * rope_gamma).view(1, 1, -1)
198
+ self.register_buffer("theta", theta, persistent=True)
199
+ self.register_buffer("increments", torch.arange(1000).view(1, 1000, 1), persistent=True)
200
+ self.tanh = None
201
+ else:
202
+ self.theta = None
203
+ self.increments = None
204
+ self.tanh = nn.Tanh()
205
+
206
+ def forward(
207
+ self,
208
+ x: torch.Tensor,
209
+ context: torch.Tensor,
210
+ context_keys: Optional[torch.Tensor] = None,
211
+ x_mask: Optional[torch.Tensor] = None,
212
+ context_mask: Optional[torch.Tensor] = None,
213
+ ) -> torch.Tensor:
214
+ B, d_model, T = x.shape
215
+ L = context.shape[1]
216
+
217
+ x_t = x.transpose(1, 2)
218
+ q = self.W_query(x_t)
219
+
220
+ k_src = context_keys if context_keys is not None else context
221
+ k = self.W_key(k_src)
222
+ v = self.W_value(context)
223
+
224
+ if not self.use_rope and self.tanh is not None:
225
+ k = self.tanh(k)
226
+
227
+ H = self.num_heads
228
+ D = self.head_dim
229
+
230
+ q = q.view(B, T, H, D).permute(2, 0, 1, 3)
231
+ k = k.view(B, L, H, D).permute(2, 0, 1, 3)
232
+ v = v.view(B, L, H, D).permute(2, 0, 1, 3)
233
+
234
+ if self.use_rope:
235
+ device = x.device
236
+
237
+ if x_mask is not None:
238
+ len_q = x_mask.sum(dim=(-2, -1)).reshape(-1, 1, 1)
239
+ else:
240
+ len_q = torch.tensor([T], device=device, dtype=torch.float32).reshape(1, 1, 1)
241
+
242
+ if context_mask is not None:
243
+ len_k = context_mask.sum(dim=(-2, -1)).reshape(-1, 1, 1)
244
+ else:
245
+ len_k = torch.tensor([L], device=device, dtype=torch.float32).reshape(1, 1, 1)
246
+
247
+ if self.increments is not None and self.increments.shape[1] >= max(T, L):
248
+ pos_q = self.increments[:, :T, :].to(device).float()
249
+ pos_k = self.increments[:, :L, :].to(device).float()
250
+ else:
251
+ pos_q = torch.arange(T, device=device, dtype=torch.float32).reshape(1, -1, 1)
252
+ pos_k = torch.arange(L, device=device, dtype=torch.float32).reshape(1, -1, 1)
253
+
254
+ norm_pos_q = pos_q / len_q
255
+ norm_pos_k = pos_k / len_k
256
+
257
+ theta = self.theta if self.theta is not None else (
258
+ (1.0 / (10000 ** (torch.arange(0, D, 2, device=device).float() / D))) * self.rope_gamma
259
+ ).view(1, 1, -1)
260
+
261
+ freqs_q = norm_pos_q * theta
262
+ freqs_k = norm_pos_k * theta
263
+
264
+ cos_q, sin_q = freqs_q.cos(), freqs_q.sin()
265
+ cos_k, sin_k = freqs_k.cos(), freqs_k.sin()
266
+
267
+ cos_q, sin_q = cos_q.unsqueeze(0), sin_q.unsqueeze(0)
268
+ cos_k, sin_k = cos_k.unsqueeze(0), sin_k.unsqueeze(0)
269
+
270
+ q = apply_rotary_pos_emb(q, cos_q, sin_q)
271
+ k = apply_rotary_pos_emb(k, cos_k, sin_k)
272
+
273
+ attn_logits = torch.matmul(q, k.transpose(-1, -2)) / self.attn_scale
274
+
275
+ if context_mask is not None:
276
+ if context_mask.dim() == 2:
277
+ context_mask = context_mask.unsqueeze(1)
278
+ cm = (context_mask == 0)
279
+ attn_logits = attn_logits.masked_fill(cm.unsqueeze(0), float("-inf"))
280
+
281
+ attn = torch.softmax(attn_logits, dim=-1)
282
+
283
+ if x_mask is not None:
284
+ if x_mask.dim() == 2:
285
+ x_mask = x_mask.unsqueeze(1)
286
+ qm = (x_mask == 0).permute(1, 0, 2).unsqueeze(-1)
287
+ attn = attn.masked_fill(qm, 0.0)
288
+
289
+ out = torch.matmul(attn, v)
290
+ out = out.permute(1, 2, 0, 3).contiguous().view(B, T, self.attn_dim)
291
+ out = self.out_fc(out)
292
+ out = self.dropout(out)
293
+
294
+ if x_mask is not None:
295
+ out = out * x_mask.transpose(1, 2)
296
+
297
+ out = out.transpose(1, 2)
298
+ return out
299
+
300
+
301
+ class CrossAttentionBlock(nn.Module):
302
+ def __init__(
303
+ self,
304
+ d_model: int,
305
+ d_context: int,
306
+ num_heads: int = 8,
307
+ attn_dim: int = 256,
308
+ use_rope: bool = True,
309
+ rope_gamma: float = 10.0,
310
+ attn_scale: Optional[float] = None,
311
+ use_residual: bool = True,
312
+ rotary_base: float = 10000.0,
313
+ ):
314
+ super().__init__()
315
+ self.use_rope = use_rope
316
+ self.use_residual = use_residual
317
+ attn_module = AttentionModule(
318
+ d_model, d_context, num_heads, attn_dim, use_rope,
319
+ rope_gamma=rope_gamma, attn_scale=attn_scale, rotary_base=rotary_base, use_residual=use_residual,
320
+ )
321
+ # Checkpoint naming: text (RoPE) -> 'attn'; style (no RoPE) -> 'attention'.
322
+ if use_rope:
323
+ self.attn = attn_module
324
+ else:
325
+ self.attention = attn_module
326
+ self.norm = LayerNormWrapper(d_model)
327
+
328
+ def forward(
329
+ self,
330
+ x: torch.Tensor,
331
+ context: torch.Tensor,
332
+ context_keys: Optional[torch.Tensor],
333
+ x_mask: Optional[torch.Tensor],
334
+ context_mask: Optional[torch.Tensor],
335
+ ) -> torch.Tensor:
336
+ if x_mask is not None:
337
+ x = x * x_mask
338
+
339
+ residual = x
340
+
341
+ if self.use_rope:
342
+ attn_out = self.attn(x, context, context_keys, x_mask, context_mask)
343
+ else:
344
+ attn_out = self.attention(x, context, context_keys, x_mask, context_mask)
345
+
346
+ if self.use_residual:
347
+ x = residual + attn_out
348
+ else:
349
+ x = attn_out
350
+
351
+ x = self.norm(x)
352
+ if x_mask is not None:
353
+ x = x * x_mask
354
+ return x
355
+
356
+
357
+ class VectorFieldEstimator(nn.Module):
358
+ def __init__(
359
+ self,
360
+ in_channels: int = 144,
361
+ hidden_channels: int = 512,
362
+ out_channels: int = 144,
363
+ text_dim: int = 256,
364
+ style_dim: int = 256,
365
+ num_style_tokens: int = 50,
366
+ num_superblocks: int = 4,
367
+ time_embed_dim: int = 64,
368
+ rope_gamma: float = 10.0,
369
+ main_blocks_cfg: dict = None,
370
+ last_convnext_cfg: dict = None,
371
+ text_n_heads: int = 4,
372
+ time_hdim: int = 256,
373
+ use_residual: bool = True,
374
+ rotary_base: float = 10000.0,
375
+ ):
376
+ super().__init__()
377
+ self.in_channels = in_channels
378
+ self.hidden_channels = hidden_channels
379
+ self.out_channels = out_channels
380
+ self.text_dim = text_dim
381
+ self.style_dim = style_dim
382
+ self.rope_gamma = rope_gamma
383
+
384
+ # Shared tiled constant ([1, 50, 256]) consumed by every style-attn W_key.
385
+ self.tile = nn.Parameter(torch.randn(1, num_style_tokens, style_dim) * 0.02)
386
+
387
+ self.proj_in = ProjectionWrapper(in_channels, hidden_channels)
388
+ self.time_encoder = TimeEncoder(time_embed_dim, hdim=time_hdim)
389
+
390
+ self.main_blocks = nn.ModuleList()
391
+
392
+ shared_attn_scale = math.sqrt(256)
393
+
394
+ mb_cfg = main_blocks_cfg or {}
395
+ lc_cfg = last_convnext_cfg or {}
396
+
397
+ c0_cfg = mb_cfg.get("convnext_0", {})
398
+ c1_cfg = mb_cfg.get("convnext_1", {})
399
+ c2_cfg = mb_cfg.get("convnext_2", {})
400
+
401
+ for _ in range(num_superblocks):
402
+ self.main_blocks.append(
403
+ ConvNeXtStack(hidden_channels, kernel_size=c0_cfg.get("ksz", 5), dilations=c0_cfg.get("dilation_lst", [1, 2, 4, 8]))
404
+ )
405
+ self.main_blocks.append(
406
+ TimeCondBlock(time_dim=time_embed_dim, channels=hidden_channels)
407
+ )
408
+ self.main_blocks.append(
409
+ ConvNeXtStack(hidden_channels, kernel_size=c1_cfg.get("ksz", 5), dilations=c1_cfg.get("dilation_lst", [1]))
410
+ )
411
+ self.main_blocks.append(
412
+ CrossAttentionBlock(
413
+ d_model=hidden_channels,
414
+ d_context=text_dim,
415
+ num_heads=text_n_heads,
416
+ attn_dim=256,
417
+ use_rope=True,
418
+ rope_gamma=self.rope_gamma,
419
+ attn_scale=shared_attn_scale,
420
+ use_residual=use_residual,
421
+ rotary_base=rotary_base,
422
+ )
423
+ )
424
+ self.main_blocks.append(
425
+ ConvNeXtStack(hidden_channels, kernel_size=c2_cfg.get("ksz", 5), dilations=c2_cfg.get("dilation_lst", [1]))
426
+ )
427
+ self.main_blocks.append(
428
+ CrossAttentionBlock(
429
+ d_model=hidden_channels,
430
+ d_context=style_dim,
431
+ num_heads=2,
432
+ attn_dim=256,
433
+ use_rope=False,
434
+ attn_scale=shared_attn_scale,
435
+ use_residual=use_residual,
436
+ rotary_base=rotary_base,
437
+ )
438
+ )
439
+
440
+ self.last_convnext = ConvNeXtStack(
441
+ hidden_channels, kernel_size=lc_cfg.get("ksz", 5), dilations=lc_cfg.get("dilation_lst", [1, 1, 1, 1])
442
+ )
443
+ self.proj_out = ProjectionWrapper(hidden_channels, out_channels)
444
+
445
+ def _load_from_state_dict(self, state_dict, prefix, local_metadata, strict,
446
+ missing_keys, unexpected_keys, error_msgs):
447
+ # Back-compat: older checkpoints stored the tiled style-key under `style_key`.
448
+ legacy_key = prefix + "style_key"
449
+ new_key = prefix + "tile"
450
+ if legacy_key in state_dict and new_key not in state_dict:
451
+ state_dict[new_key] = state_dict.pop(legacy_key)
452
+ return super()._load_from_state_dict(
453
+ state_dict, prefix, local_metadata, strict,
454
+ missing_keys, unexpected_keys, error_msgs,
455
+ )
456
+
457
+ def forward(
458
+ self,
459
+ noisy_latent: torch.Tensor,
460
+ text_emb: torch.Tensor,
461
+ style_ttl: torch.Tensor,
462
+ latent_mask: torch.Tensor,
463
+ text_mask: torch.Tensor,
464
+ current_step: torch.Tensor,
465
+ total_step: Optional[torch.Tensor] = None,
466
+ ) -> torch.Tensor:
467
+ B = noisy_latent.shape[0]
468
+
469
+ if total_step is not None:
470
+ t_norm = current_step.reshape(B, 1, 1) / total_step.reshape(B, 1, 1)
471
+ reciprocal = 1.0 / total_step.reshape(B, 1, 1)
472
+ t_norm_flat = t_norm.reshape(B)
473
+ else:
474
+ t_norm_flat = current_step.reshape(B)
475
+
476
+ t_emb = self.time_encoder(t_norm_flat)
477
+ text_blc = text_emb.transpose(1, 2)
478
+
479
+ x = self.proj_in(noisy_latent)
480
+ x = x * latent_mask
481
+
482
+ for i, block in enumerate(self.main_blocks):
483
+ idx_in_super = i % 6
484
+ if idx_in_super == 0:
485
+ x = block(x, mask=latent_mask)
486
+ elif idx_in_super == 1:
487
+ x = block(x, t_emb)
488
+ x = x * latent_mask
489
+ elif idx_in_super == 2:
490
+ x = block(x, mask=latent_mask)
491
+ elif idx_in_super == 3:
492
+ x = block(x, context=text_blc, context_keys=None,
493
+ x_mask=latent_mask, context_mask=text_mask)
494
+ elif idx_in_super == 4:
495
+ x = block(x, mask=latent_mask)
496
+ elif idx_in_super == 5:
497
+ x = block(x, context=style_ttl,
498
+ context_keys=self.tile.expand(B, -1, -1),
499
+ x_mask=latent_mask, context_mask=None)
500
+
501
+ x = self.last_convnext(x, mask=latent_mask)
502
+ diff_out = self.proj_out(x) * latent_mask
503
+
504
+ if total_step is not None:
505
+ denoised = noisy_latent + reciprocal * diff_out
506
+ return denoised * latent_mask
507
+ return diff_out