ashishkblink commited on
Commit
3f78fdf
·
verified ·
1 Parent(s): fb2f21a

Upload f5_tts/model/backbones/mmdit.py with huggingface_hub

Browse files
Files changed (1) hide show
  1. f5_tts/model/backbones/mmdit.py +146 -0
f5_tts/model/backbones/mmdit.py ADDED
@@ -0,0 +1,146 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ """
2
+ ein notation:
3
+ b - batch
4
+ n - sequence
5
+ nt - text sequence
6
+ nw - raw wave length
7
+ d - dimension
8
+ """
9
+
10
+ from __future__ import annotations
11
+
12
+ import torch
13
+ from torch import nn
14
+
15
+ from x_transformers.x_transformers import RotaryEmbedding
16
+
17
+ from f5_tts.model.modules import (
18
+ TimestepEmbedding,
19
+ ConvPositionEmbedding,
20
+ MMDiTBlock,
21
+ AdaLayerNormZero_Final,
22
+ precompute_freqs_cis,
23
+ get_pos_embed_indices,
24
+ )
25
+
26
+
27
+ # text embedding
28
+
29
+
30
+ class TextEmbedding(nn.Module):
31
+ def __init__(self, out_dim, text_num_embeds):
32
+ super().__init__()
33
+ self.text_embed = nn.Embedding(text_num_embeds + 1, out_dim) # will use 0 as filler token
34
+
35
+ self.precompute_max_pos = 1024
36
+ self.register_buffer("freqs_cis", precompute_freqs_cis(out_dim, self.precompute_max_pos), persistent=False)
37
+
38
+ def forward(self, text: int["b nt"], drop_text=False) -> int["b nt d"]: # noqa: F722
39
+ text = text + 1
40
+ if drop_text:
41
+ text = torch.zeros_like(text)
42
+ text = self.text_embed(text)
43
+
44
+ # sinus pos emb
45
+ batch_start = torch.zeros((text.shape[0],), dtype=torch.long)
46
+ batch_text_len = text.shape[1]
47
+ pos_idx = get_pos_embed_indices(batch_start, batch_text_len, max_pos=self.precompute_max_pos)
48
+ text_pos_embed = self.freqs_cis[pos_idx]
49
+
50
+ text = text + text_pos_embed
51
+
52
+ return text
53
+
54
+
55
+ # noised input & masked cond audio embedding
56
+
57
+
58
+ class AudioEmbedding(nn.Module):
59
+ def __init__(self, in_dim, out_dim):
60
+ super().__init__()
61
+ self.linear = nn.Linear(2 * in_dim, out_dim)
62
+ self.conv_pos_embed = ConvPositionEmbedding(out_dim)
63
+
64
+ def forward(self, x: float["b n d"], cond: float["b n d"], drop_audio_cond=False): # noqa: F722
65
+ if drop_audio_cond:
66
+ cond = torch.zeros_like(cond)
67
+ x = torch.cat((x, cond), dim=-1)
68
+ x = self.linear(x)
69
+ x = self.conv_pos_embed(x) + x
70
+ return x
71
+
72
+
73
+ # Transformer backbone using MM-DiT blocks
74
+
75
+
76
+ class MMDiT(nn.Module):
77
+ def __init__(
78
+ self,
79
+ *,
80
+ dim,
81
+ depth=8,
82
+ heads=8,
83
+ dim_head=64,
84
+ dropout=0.1,
85
+ ff_mult=4,
86
+ text_num_embeds=256,
87
+ mel_dim=100,
88
+ ):
89
+ super().__init__()
90
+
91
+ self.time_embed = TimestepEmbedding(dim)
92
+ self.text_embed = TextEmbedding(dim, text_num_embeds)
93
+ self.audio_embed = AudioEmbedding(mel_dim, dim)
94
+
95
+ self.rotary_embed = RotaryEmbedding(dim_head)
96
+
97
+ self.dim = dim
98
+ self.depth = depth
99
+
100
+ self.transformer_blocks = nn.ModuleList(
101
+ [
102
+ MMDiTBlock(
103
+ dim=dim,
104
+ heads=heads,
105
+ dim_head=dim_head,
106
+ dropout=dropout,
107
+ ff_mult=ff_mult,
108
+ context_pre_only=i == depth - 1,
109
+ )
110
+ for i in range(depth)
111
+ ]
112
+ )
113
+ self.norm_out = AdaLayerNormZero_Final(dim) # final modulation
114
+ self.proj_out = nn.Linear(dim, mel_dim)
115
+
116
+ def forward(
117
+ self,
118
+ x: float["b n d"], # nosied input audio # noqa: F722
119
+ cond: float["b n d"], # masked cond audio # noqa: F722
120
+ text: int["b nt"], # text # noqa: F722
121
+ time: float["b"] | float[""], # time step # noqa: F821 F722
122
+ drop_audio_cond, # cfg for cond audio
123
+ drop_text, # cfg for text
124
+ mask: bool["b n"] | None = None, # noqa: F722
125
+ ):
126
+ batch = x.shape[0]
127
+ if time.ndim == 0:
128
+ time = time.repeat(batch)
129
+
130
+ # t: conditioning (time), c: context (text + masked cond audio), x: noised input audio
131
+ t = self.time_embed(time)
132
+ c = self.text_embed(text, drop_text=drop_text)
133
+ x = self.audio_embed(x, cond, drop_audio_cond=drop_audio_cond)
134
+
135
+ seq_len = x.shape[1]
136
+ text_len = text.shape[1]
137
+ rope_audio = self.rotary_embed.forward_from_seq_len(seq_len)
138
+ rope_text = self.rotary_embed.forward_from_seq_len(text_len)
139
+
140
+ for block in self.transformer_blocks:
141
+ c, x = block(x, c, t, mask=mask, rope=rope_audio, c_rope=rope_text)
142
+
143
+ x = self.norm_out(x, t)
144
+ output = self.proj_out(x)
145
+
146
+ return output