ashishkblink committed on
Commit 6972ead · verified · 1 Parent(s): 441f260

Upload f5_tts/model/utils.py with huggingface_hub

Files changed (1)
  1. f5_tts/model/utils.py +191 -0
f5_tts/model/utils.py ADDED
@@ -0,0 +1,191 @@
from __future__ import annotations

import os
import random
from collections import defaultdict
from importlib.resources import files

import torch
from torch.nn.utils.rnn import pad_sequence

import jieba
from pypinyin import lazy_pinyin, Style


# seed everything


def seed_everything(seed=0):
    random.seed(seed)
    os.environ["PYTHONHASHSEED"] = str(seed)
    torch.manual_seed(seed)
    torch.cuda.manual_seed(seed)
    torch.cuda.manual_seed_all(seed)
    torch.backends.cudnn.deterministic = True
    torch.backends.cudnn.benchmark = False


# helpers


def exists(v):
    return v is not None


def default(v, d):
    return v if exists(v) else d


# tensor helpers


def lens_to_mask(t: int["b"], length: int | None = None) -> bool["b n"]:  # noqa: F722 F821
    if not exists(length):
        length = t.amax()

    seq = torch.arange(length, device=t.device)
    return seq[None, :] < t[:, None]
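
# Illustration (not part of the uploaded file): for lengths [2, 3] with no
# explicit `length`, the mask is padded to the batch max, e.g.
#   >>> lens_to_mask(torch.tensor([2, 3]))
#   tensor([[ True,  True, False],
#           [ True,  True,  True]])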


def mask_from_start_end_indices(seq_len: int["b"], start: int["b"], end: int["b"]):  # noqa: F722 F821
    max_seq_len = seq_len.max().item()
    seq = torch.arange(max_seq_len, device=start.device).long()
    start_mask = seq[None, :] >= start[:, None]
    end_mask = seq[None, :] < end[:, None]
    return start_mask & end_mask


def mask_from_frac_lengths(seq_len: int["b"], frac_lengths: float["b"]):  # noqa: F722 F821
    lengths = (frac_lengths * seq_len).long()
    max_start = seq_len - lengths

    rand = torch.rand_like(frac_lengths)
    start = (max_start * rand).long().clamp(min=0)
    end = start + lengths

    return mask_from_start_end_indices(seq_len, start, end)
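
# Illustration (not part of the uploaded file): with seq_len [4] and
# frac_lengths [0.5], a contiguous span covering half the sequence is masked
# at a random offset; a draw with start=1 would give
#   >>> mask_from_frac_lengths(torch.tensor([4]), torch.tensor([0.5]))
#   tensor([[False,  True,  True, False]])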


def maybe_masked_mean(t: float["b n d"], mask: bool["b n"] = None) -> float["b d"]:  # noqa: F722
    if not exists(mask):
        return t.mean(dim=1)

    t = torch.where(mask[:, :, None], t, torch.tensor(0.0, device=t.device))
    num = t.sum(dim=1)
    den = mask.float().sum(dim=1)

    return num / den.clamp(min=1.0)


# simple utf-8 tokenizer, since paper went character based
def list_str_to_tensor(text: list[str], padding_value=-1) -> int["b nt"]:  # noqa: F722
    list_tensors = [torch.tensor([*bytes(t, "UTF-8")]) for t in text]  # ByT5 style
    text = pad_sequence(list_tensors, padding_value=padding_value, batch_first=True)
    return text
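
# Illustration (not part of the uploaded file): strings become UTF-8 byte ids,
# right-padded with -1 to the longest item, e.g.
#   >>> list_str_to_tensor(["ab", "c"])
#   tensor([[97, 98],
#           [99, -1]])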


# char tokenizer, based on custom dataset's extracted .txt file
def list_str_to_idx(
    text: list[str] | list[list[str]],
    vocab_char_map: dict[str, int],  # {char: idx}
    padding_value=-1,
) -> int["b nt"]:  # noqa: F722
    list_idx_tensors = [torch.tensor([vocab_char_map.get(c, 0) for c in t]) for t in text]  # pinyin or char style
    text = pad_sequence(list_idx_tensors, padding_value=padding_value, batch_first=True)
    return text


# Get tokenizer


def get_tokenizer(dataset_name, tokenizer: str = "pinyin"):
    """
    tokenizer   - "pinyin" do g2p for only chinese characters, need .txt vocab_file
                - "char" for char-wise tokenizer, need .txt vocab_file
                - "byte" for utf-8 tokenizer
                - "custom" if you're directly passing in a path to the vocab.txt you want to use
    vocab_size  - if use "pinyin", all available pinyin types, common alphabets (also those with accent) and symbols
                - if use "char", derived from unfiltered character & symbol counts of custom dataset
                - if use "byte", set to 256 (unicode byte range)
    """
    if tokenizer in ["pinyin", "char"]:
        tokenizer_path = os.path.join(files("f5_tts").joinpath("../../data"), f"{dataset_name}_{tokenizer}/vocab.txt")
        with open(tokenizer_path, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)
        assert vocab_char_map[" "] == 0, "make sure space is of idx 0 in vocab.txt, cuz 0 is used for unknown char"

    elif tokenizer == "byte":
        vocab_char_map = None
        vocab_size = 256

    elif tokenizer == "custom":
        with open(dataset_name, "r", encoding="utf-8") as f:
            vocab_char_map = {}
            for i, char in enumerate(f):
                vocab_char_map[char[:-1]] = i
        vocab_size = len(vocab_char_map)

    return vocab_char_map, vocab_size
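
# Illustration (not part of the uploaded file): typical calls. The dataset name
# below is only an example; "pinyin"/"char" modes assume a matching
# data/<dataset_name>_<tokenizer>/vocab.txt exists relative to the package.
#   >>> vocab_char_map, vocab_size = get_tokenizer("Emilia_ZH_EN", "pinyin")
#   >>> vocab_char_map, vocab_size = get_tokenizer("unused", "byte")  # -> (None, 256)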


# convert char to pinyin

jieba.initialize()
print("Word segmentation module jieba initialized.\n")


def convert_char_to_pinyin(text_list, polyphone=True):
    final_text_list = []
    custom_trans = str.maketrans(
        {";": ",", "“": '"', "”": '"', "‘": "'", "’": "'"}
    )  # add custom trans here, to address oov

    def is_chinese(c):
        return (
            "\u3100" <= c <= "\u9fff"  # common chinese characters
        )

    for text in text_list:
        char_list = []
        text = text.translate(custom_trans)
        for seg in jieba.cut(text):
            seg_byte_len = len(bytes(seg, "UTF-8"))
            if seg_byte_len == len(seg):  # if pure alphabets and symbols
                if char_list and seg_byte_len > 1 and char_list[-1] not in " :'\"":
                    char_list.append(" ")
                char_list.extend(seg)
            elif polyphone and seg_byte_len == 3 * len(seg):  # if pure east asian characters
                seg_ = lazy_pinyin(seg, style=Style.TONE3, tone_sandhi=True)
                for i, c in enumerate(seg):
                    if is_chinese(c):
                        char_list.append(" ")
                    char_list.append(seg_[i])
            else:  # if mixed characters, alphabets and symbols
                for c in seg:
                    if ord(c) < 256:
                        char_list.extend(c)
                    elif is_chinese(c):
                        char_list.append(" ")
                        char_list.extend(lazy_pinyin(c, style=Style.TONE3, tone_sandhi=True))
                    else:
                        char_list.append(c)
        final_text_list.append(char_list)

    return final_text_list
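
# Illustration (not part of the uploaded file): Chinese characters come out as
# space-preceded tone3 pinyin while Latin text passes through character by
# character; exact output depends on jieba's segmentation, but roughly
#   >>> convert_char_to_pinyin(["中国 hello"])
#   [[' ', 'zhong1', ' ', 'guo2', ' ', 'h', 'e', 'l', 'l', 'o']]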


# filter func for dirty data with many repetitions


def repetition_found(text, length=2, tolerance=10):
    pattern_count = defaultdict(int)
    for i in range(len(text) - length + 1):
        pattern = text[i : i + length]
        pattern_count[pattern] += 1
    for pattern, count in pattern_count.items():
        if count > tolerance:
            return True
    return False
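
# Illustration (not part of the uploaded file): any length-`length` substring
# occurring more than `tolerance` times flags the text, e.g.
#   >>> repetition_found("ha" * 12)  # "ha" occurs 12 times, above the default tolerance of 10
#   True
#   >>> repetition_found("hello")
#   False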