gratias98 commited on
Commit
ecd73f4
·
verified ·
1 Parent(s): 3c74bb7

Update tts.py

Browse files
Files changed (1) hide show
  1. tts.py +169 -106
tts.py CHANGED
@@ -1,117 +1,180 @@
1
- import os, re, tempfile, torch, sys
 
 
 
 
 
 
 
 
 
 
2
  import numpy as np
3
- import psutil
4
  from huggingface_hub import hf_hub_download
5
 
 
6
  if "vits" not in sys.path:
7
- sys.path.append("vits")
8
 
9
- from vits import commons, utils
10
  from vits.models import SynthesizerTrn
11
 
12
- # Load languages
13
  TTS_LANGUAGES = {}
14
- with open("data/tts/all_langs.tsv") as f:
15
- TTS_LANGUAGES = {line.split(" ",1)[0].strip(): line.split(" ",1)[1].strip() for line in f}
16
-
17
- class TextMapper:
18
- def __init__(self, vocab_file):
19
- self.symbols = [x.strip() for x in open(vocab_file, encoding="utf-8")]
20
- self.SPACE_ID = self.symbols.index(" ")
21
- self._symbol_to_id = {s:i for i,s in enumerate(self.symbols)}
22
- self._id_to_symbol = {i:s for i,s in enumerate(self.symbols)}
23
-
24
- def text_to_sequence(self, text, cleaner_names):
25
- return [self._symbol_to_id[s] for s in text.strip()]
26
-
27
- def uromanize(self, text, uroman_pl):
28
- with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
29
- with open(tf.name, "w") as f:
30
- f.write(text + "\n")
31
- os.system(f"perl {uroman_pl} -l xxx < {tf.name} > {tf2.name}")
32
- with open(tf2.name) as f:
33
- return re.sub(r"\s+", " ", f.read()).strip()
34
-
35
- def get_text(self, text, hps):
36
- text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
37
- if hps.data.add_blank:
38
- text_norm = commons.intersperse(text_norm, 0)
39
- return torch.LongTensor(text_norm)
40
-
41
- def filter_oov(self, text, lang=None):
42
- text = text.replace("ț", "ţ") if lang == "ron" else text
43
- return "".join(c for c in text if c in self._symbol_to_id)
44
-
45
- def synthesize(text=None, lang=None, speed=1.0):
46
- # Memory check
47
- if psutil.virtual_memory().percent > 85:
48
- raise RuntimeError("System memory usage too high")
49
-
50
- lang_code = lang.split()[0].strip()
51
-
52
- # Download model files
53
- model_dir = f"models/{lang_code}"
54
- files = {
55
- "vocab": hf_hub_download("facebook/mms-tts", "vocab.txt", subfolder=model_dir),
56
- "config": hf_hub_download("facebook/mms-tts", "config.json", subfolder=model_dir),
57
- "model": hf_hub_download("facebook/mms-tts", "G_100000.pth", subfolder=model_dir)
58
- }
59
-
60
- # Setup device
61
- device = torch.device("cuda" if torch.cuda.is_available() else
62
- "mps" if hasattr(torch.backends, "mps") and
63
- torch.backends.mps.is_available() and
64
- torch.backends.mps.is_built() else "cpu")
65
-
66
- # Initialize model
67
- hps = utils.get_hparams_from_file(files["config"])
68
- text_mapper = TextMapper(files["vocab"])
69
- net_g = SynthesizerTrn(
70
- len(text_mapper.symbols),
71
- hps.data.filter_length // 2 + 1,
72
- hps.train.segment_size // hps.data.hop_length,
73
- **hps.model
74
- ).to(device).eval()
75
-
76
- utils.load_checkpoint(files["model"], net_g, None)
77
-
78
- # Process text
79
- if hps.data.training_files.endswith(".uroman"):
80
- text = text_mapper.uromanize(text, os.path.join("uroman", "bin", "uroman.pl"))
81
-
82
- text = text_mapper.filter_oov(text.lower(), lang=lang)
83
- stn_tst = text_mapper.get_text(text, hps)
84
-
85
- # Generate audio
86
- try:
87
- with torch.no_grad():
88
- x_tst = stn_tst.unsqueeze(0).to(device)
89
- x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
90
-
91
- hyp = net_g.infer(
92
- x_tst,
93
- x_tst_lengths,
94
- noise_scale=0.667,
95
- noise_scale_w=0.8,
96
- length_scale=1.0/speed
97
- )[0][0,0].cpu().float().numpy()
98
-
99
- # Cleanup
100
- torch.cuda.empty_cache() if device.type == "cuda" else None
101
-
102
- return (hps.data.sampling_rate, hyp), text
103
-
104
- except RuntimeError as e:
105
- if "out of memory" in str(e):
106
- torch.cuda.empty_cache()
107
- device = torch.device("cpu")
108
- return synthesize(text, lang, speed)
109
- raise e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
110
 
111
  TTS_EXAMPLES = [
112
- ["I am going to the store.", "eng (English)", 1.0],
113
- ["안녕하세요.", "kor (Korean)", 1.0],
114
- ["क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)", 1.0],
115
- ["Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)", 1.0],
116
- ["Mu zo murna a cikin ƙasar.", "hau (Hausa)", 1.0]
117
  ]
 
1
+ # Copyright (c) Facebook, Inc. and its affiliates.
2
+ #
3
+ # This source code is licensed under the MIT license found in the
4
+ # LICENSE file in the root directory of this source tree.
5
+
6
+ import os
7
+ import re
8
+ import tempfile
9
+ import torch
10
+ import sys
11
+ import gradio as gr
12
  import numpy as np
13
+
14
  from huggingface_hub import hf_hub_download
15
 
16
+ # Setup TTS env
17
  if "vits" not in sys.path:
18
+ sys.path.append("vits")
19
 
20
+ from vits import commons, utils
21
  from vits.models import SynthesizerTrn
22
 
23
+
24
  TTS_LANGUAGES = {}
25
+ with open(f"data/tts/all_langs.tsv") as f:
26
+ for line in f:
27
+ iso, name = line.split(" ", 1)
28
+ TTS_LANGUAGES[iso.strip()] = name.strip()
29
+
30
+
31
+ class TextMapper(object):
32
+ def __init__(self, vocab_file):
33
+ self.symbols = [
34
+ x.replace("\n", "") for x in open(vocab_file, encoding="utf-8").readlines()
35
+ ]
36
+ self.SPACE_ID = self.symbols.index(" ")
37
+ self._symbol_to_id = {s: i for i, s in enumerate(self.symbols)}
38
+ self._id_to_symbol = {i: s for i, s in enumerate(self.symbols)}
39
+
40
+ def text_to_sequence(self, text, cleaner_names):
41
+ """Converts a string of text to a sequence of IDs corresponding to the symbols in the text.
42
+ Args:
43
+ text: string to convert to a sequence
44
+ cleaner_names: names of the cleaner functions to run the text through
45
+ Returns:
46
+ List of integers corresponding to the symbols in the text
47
+ """
48
+ sequence = []
49
+ clean_text = text.strip()
50
+ for symbol in clean_text:
51
+ symbol_id = self._symbol_to_id[symbol]
52
+ sequence += [symbol_id]
53
+ return sequence
54
+
55
+ def uromanize(self, text, uroman_pl):
56
+ iso = "xxx"
57
+ with tempfile.NamedTemporaryFile() as tf, tempfile.NamedTemporaryFile() as tf2:
58
+ with open(tf.name, "w") as f:
59
+ f.write("\n".join([text]))
60
+ cmd = f"perl " + uroman_pl
61
+ cmd += f" -l {iso} "
62
+ cmd += f" < {tf.name} > {tf2.name}"
63
+ os.system(cmd)
64
+ outtexts = []
65
+ with open(tf2.name) as f:
66
+ for line in f:
67
+ line = re.sub(r"\s+", " ", line).strip()
68
+ outtexts.append(line)
69
+ outtext = outtexts[0]
70
+ return outtext
71
+
72
+ def get_text(self, text, hps):
73
+ text_norm = self.text_to_sequence(text, hps.data.text_cleaners)
74
+ if hps.data.add_blank:
75
+ text_norm = commons.intersperse(text_norm, 0)
76
+ text_norm = torch.LongTensor(text_norm)
77
+ return text_norm
78
+
79
+ def filter_oov(self, text, lang=None):
80
+ text = self.preprocess_char(text, lang=lang)
81
+ val_chars = self._symbol_to_id
82
+ txt_filt = "".join(list(filter(lambda x: x in val_chars, text)))
83
+ return txt_filt
84
+
85
+ def preprocess_char(self, text, lang=None):
86
+ """
87
+ Special treatement of characters in certain languages
88
+ """
89
+ if lang == "ron":
90
+ text = text.replace("ț", "ţ")
91
+ print(f"{lang} (ț -> ţ): {text}")
92
+ return text
93
+
94
+
95
+ def synthesize(text=None, lang=None, speed=None):
96
+ if speed is None:
97
+ speed = 1.0
98
+
99
+ lang_code = lang.split()[0].strip()
100
+
101
+ vocab_file = hf_hub_download(
102
+ repo_id="facebook/mms-tts",
103
+ filename="vocab.txt",
104
+ subfolder=f"models/{lang_code}",
105
+ )
106
+ config_file = hf_hub_download(
107
+ repo_id="facebook/mms-tts",
108
+ filename="config.json",
109
+ subfolder=f"models/{lang_code}",
110
+ )
111
+ g_pth = hf_hub_download(
112
+ repo_id="facebook/mms-tts",
113
+ filename="G_100000.pth",
114
+ subfolder=f"models/{lang_code}",
115
+ )
116
+
117
+ if torch.cuda.is_available():
118
+ device = torch.device("cuda")
119
+ elif (
120
+ hasattr(torch.backends, "mps")
121
+ and torch.backends.mps.is_available()
122
+ and torch.backends.mps.is_built()
123
+ ):
124
+ device = torch.device("mps")
125
+ else:
126
+ device = torch.device("cpu")
127
+
128
+ print(f"Run inference with {device}")
129
+
130
+ assert os.path.isfile(config_file), f"{config_file} doesn't exist"
131
+ hps = utils.get_hparams_from_file(config_file)
132
+ text_mapper = TextMapper(vocab_file)
133
+ net_g = SynthesizerTrn(
134
+ len(text_mapper.symbols),
135
+ hps.data.filter_length // 2 + 1,
136
+ hps.train.segment_size // hps.data.hop_length,
137
+ **hps.model,
138
+ )
139
+ net_g.to(device)
140
+ _ = net_g.eval()
141
+
142
+ _ = utils.load_checkpoint(g_pth, net_g, None)
143
+
144
+ is_uroman = hps.data.training_files.split(".")[-1] == "uroman"
145
+
146
+ if is_uroman:
147
+ uroman_dir = "uroman"
148
+ assert os.path.exists(uroman_dir)
149
+ uroman_pl = os.path.join(uroman_dir, "bin", "uroman.pl")
150
+ text = text_mapper.uromanize(text, uroman_pl)
151
+
152
+ text = text.lower()
153
+ text = text_mapper.filter_oov(text, lang=lang)
154
+ stn_tst = text_mapper.get_text(text, hps)
155
+ with torch.no_grad():
156
+ x_tst = stn_tst.unsqueeze(0).to(device)
157
+ x_tst_lengths = torch.LongTensor([stn_tst.size(0)]).to(device)
158
+ hyp = (
159
+ net_g.infer(
160
+ x_tst,
161
+ x_tst_lengths,
162
+ noise_scale=0.667,
163
+ noise_scale_w=0.8,
164
+ length_scale=1.0 / speed,
165
+ )[0][0, 0]
166
+ .cpu()
167
+ .float()
168
+ .numpy()
169
+ )
170
+
171
+ return (hps.data.sampling_rate, hyp), text
172
+
173
 
174
  TTS_EXAMPLES = [
175
+ ["I am going to the store.", "eng (English)", 1.0],
176
+ ["안녕하세요.", "kor (Korean)", 1.0],
177
+ ["क्या मुझे पीने का पानी मिल सकता है?", "hin (Hindi)", 1.0],
178
+ ["Tanış olmağıma çox şadam", "azj-script_latin (Azerbaijani, North)", 1.0],
179
+ ["Mu zo murna a cikin ƙasar.", "hau (Hausa)", 1.0],
180
  ]