utkarsh2299 committed
Commit d4fe7cc · verified · 1 Parent(s): e330495

Upload inference_w_sil_alpha.py

Files changed (1)
inference_w_sil_alpha.py +209 −0
inference_w_sil_alpha.py ADDED
@@ -0,0 +1,209 @@
import sys
import os
# Replace this path with your local HiFi-GAN checkout so that Generator can be imported from models.py
sys.path.append("hifigan")
import argparse
import torch
from espnet2.bin.tts_inference import Text2Speech
from models import Generator
from scipy.io.wavfile import write
from meldataset import MAX_WAV_VALUE
from env import AttrDict
import json
import yaml
import concurrent.futures
import numpy as np
import time
import re

from text_preprocess_for_inference import TTSDurAlignPreprocessor, CharTextPreprocessor, TTSPreprocessor

SAMPLING_RATE = 22050

def load_hifigan_vocoder(language, gender, device):
    # Load the HiFi-GAN vocoder configuration file and generator checkpoint for the specified language and gender
    vocoder_config = f"vocoder/{gender}/{language}/config.json"
    vocoder_generator = f"vocoder/{gender}/{language}/generator"
    # Read the contents of the vocoder configuration file
    with open(vocoder_config, 'r') as f:
        data = f.read()
    json_config = json.loads(data)
    h = AttrDict(json_config)
    torch.manual_seed(h.seed)
    # Move the generator model to the specified device (CPU or GPU)
    device = torch.device(device)
    generator = Generator(h).to(device)
    state_dict_g = torch.load(vocoder_generator, map_location=device)
    generator.load_state_dict(state_dict_g['generator'])
    generator.eval()
    generator.remove_weight_norm()

    # Return the loaded and prepared HiFi-GAN generator model
    return generator

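# Illustrative only: a minimal smoke test for the vocoder loader above. The 80-mel-bin input shape
# is an assumption (typical for HiFi-GAN configs), not something read from this repo's config.json,
# and "hindi"/"female" are example values.
def _smoke_test_vocoder(language="hindi", gender="female", device="cpu"):
    generator = load_hifigan_vocoder(language, gender, device)
    dummy_mel = torch.zeros(1, 80, 100).to(device)  # (batch, mel_bins, frames)
    with torch.no_grad():
        audio = generator(dummy_mel).squeeze()
    print("Generated", audio.shape[0], "samples from 100 dummy frames")
    return audio
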
def load_fastspeech2_model(language, gender, device):

    # Update config.yaml based on language and gender so that the stats files resolve to absolute paths
    with open(f"{language}/{gender}/model/config.yaml", "r") as file:
        config = yaml.safe_load(file)

    current_working_directory = os.getcwd()
    feat = "model/feats_stats.npz"
    pitch = "model/pitch_stats.npz"
    energy = "model/energy_stats.npz"

    feat_path = os.path.join(current_working_directory, language, gender, feat)
    pitch_path = os.path.join(current_working_directory, language, gender, pitch)
    energy_path = os.path.join(current_working_directory, language, gender, energy)

    config["normalize_conf"]["stats_file"] = feat_path
    config["pitch_normalize_conf"]["stats_file"] = pitch_path
    config["energy_normalize_conf"]["stats_file"] = energy_path

    with open(f"{language}/{gender}/model/config.yaml", "w") as file:
        yaml.dump(config, file)

    tts_model = f"{language}/{gender}/model/model.pth"
    tts_config = f"{language}/{gender}/model/config.yaml"

    return Text2Speech(train_config=tts_config, model_file=tts_model, device=device)

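# Sanity-check sketch (not part of the original script): the two loaders above assume a fixed
# on-disk layout relative to the working directory. The paths below are taken directly from
# load_fastspeech2_model and load_hifigan_vocoder; "hindi"/"female" are example values.
def _check_expected_assets(language="hindi", gender="female"):
    expected = [
        f"{language}/{gender}/model/config.yaml",
        f"{language}/{gender}/model/model.pth",
        f"{language}/{gender}/model/feats_stats.npz",
        f"{language}/{gender}/model/pitch_stats.npz",
        f"{language}/{gender}/model/energy_stats.npz",
        f"vocoder/{gender}/{language}/config.json",
        f"vocoder/{gender}/{language}/generator",
    ]
    missing = [p for p in expected if not os.path.exists(p)]
    if missing:
        print("Missing model/vocoder assets:", missing)
    return not missing
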
def text_synthesis(language, gender, sample_text, vocoder, model, MAX_WAV_VALUE, device, alpha):
    # Perform Text-to-Speech synthesis
    with torch.no_grad():
        # The FastSpeech2 model for the specified language and gender is loaded once in __main__ and passed in
        # model = load_fastspeech2_model(language, gender, device)

        # Generate mel-spectrograms from the input text using the FastSpeech2 model
        out = model(sample_text, decode_conf={"alpha": alpha})
        print("TTS Done")
        # Transpose the denormalized features to (1, n_mels, frames) and apply a fixed gain before vocoding
        x = out["feat_gen_denorm"].T.unsqueeze(0) * 2.3262
        x = x.to(device)

        # Use the HiFi-GAN vocoder to convert mel-spectrograms to raw audio waveforms
        y_g_hat = vocoder(x)
        audio = y_g_hat.squeeze()
        audio = audio * MAX_WAV_VALUE
        audio = audio.cpu().numpy().astype('int16')

        # Return the synthesized audio
        return audio

def split_into_chunks(text, words_per_chunk=100):
    words = text.split()
    chunks = [words[i:i + words_per_chunk] for i in range(0, len(words), words_per_chunk)]
    return [' '.join(chunk) for chunk in chunks]


def extract_text_alpha_chunks(text, default_alpha=1.0):
    # Parse inline control tags: <alpha=X> changes the speed factor for the text that follows,
    # and <sil=Xms> / <sil=Xs> inserts an explicit silence of the given duration.
    alpha_pattern = r"<alpha=([0-9.]+)>"
    sil_pattern = r"<sil=([0-9.]+)(ms|s)>"

    chunks = []
    alpha = default_alpha

    # re.split with one capture group yields [text_before, alpha_1, text_1, alpha_2, text_2, ...]
    alpha_blocks = re.split(alpha_pattern, text)
    i = 0
    while i < len(alpha_blocks):
        if i == 0:
            current_block = alpha_blocks[i]
            i += 1
        else:
            alpha = float(alpha_blocks[i])
            i += 1
            current_block = alpha_blocks[i] if i < len(alpha_blocks) else ""
            i += 1

        # Replace each silence tag with a placeholder token and remember its duration in seconds
        sil_matches = list(re.finditer(sil_pattern, current_block))
        sil_placeholders = {}
        for j, match in enumerate(sil_matches):
            tag = match.group(0)
            value = float(match.group(1))
            unit = match.group(2)
            duration = value / 1000.0 if unit == "ms" else value
            placeholder = f"__SIL_{j}__"
            sil_placeholders[placeholder] = duration
            current_block = current_block.replace(tag, f" {placeholder} ")

        # Split on '.' so each sentence becomes its own chunk; the periods themselves are dropped here
        sentences = [s.strip() for s in current_block.split('.') if s.strip()]
        for sentence in sentences:
            words = sentence.split()
            buffer = []
            for word in words:
                if word in sil_placeholders:
                    if buffer:
                        chunks.append((" ".join(buffer), alpha, False, None))
                        buffer = []
                    chunks.append(("", alpha, True, sil_placeholders[word]))
                else:
                    buffer.append(word)
            if buffer:
                chunks.append((" ".join(buffer), alpha, False, None))
    # Each chunk is a tuple: (text, alpha, is_silence, silence_duration_in_seconds)
    return chunks

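# Illustrative only: the example sentence and tag values below are made up, not from this commit.
# They show the (text, alpha, is_silence, silence_seconds) tuples that extract_text_alpha_chunks returns.
def _demo_extract_text_alpha_chunks():
    demo = "Welcome to the demo. <sil=500ms> This part is normal. <alpha=0.8> This part is faster."
    chunks = extract_text_alpha_chunks(demo, default_alpha=1.0)
    # Expected result:
    #   ('Welcome to the demo', 1.0, False, None)    # spoken at the default speed
    #   ('', 1.0, True, 0.5)                         # 500 ms of inserted silence
    #   ('This part is normal', 1.0, False, None)
    #   ('This part is faster', 0.8, False, None)    # spoken faster because of <alpha=0.8>
    return chunks
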
if __name__ == "__main__":
    parser = argparse.ArgumentParser(description="Text-to-Speech Inference")
    parser.add_argument("--language", type=str, required=True, help="Language (e.g., hindi)")
    parser.add_argument("--gender", type=str, required=True, help="Gender (e.g., female)")
    parser.add_argument("--sample_text", type=str, required=True, help="Text to be synthesized")
    parser.add_argument("--output_file", type=str, default="", help="Output WAV file path")
    parser.add_argument("--alpha", type=float, default=1, help="Alpha parameter for speed control (e.g. 1.1 (slow) or 0.8 (fast))")

    args = parser.parse_args()

    phone_dictionary = {}
    # Set the device
    device = "cuda" if torch.cuda.is_available() else "cpu"

    # Load the HiFi-GAN vocoder and FastSpeech2 model for the requested language and gender
    vocoder = load_hifigan_vocoder(args.language, args.gender, device)
    model = load_fastspeech2_model(args.language, args.gender, device)

    # Pick the text preprocessor that matches the language
    if args.language == "urdu" or args.language == "punjabi":
        preprocessor = CharTextPreprocessor()
    elif args.language == "english":
        preprocessor = TTSPreprocessor()
    else:
        preprocessor = TTSDurAlignPreprocessor()

    start_time = time.time()
    audio_arr = []
    # Note: this coarse word-count chunking is computed but not used; synthesis below is driven
    # by the alpha/silence-aware chunks from extract_text_alpha_chunks.
    result = split_into_chunks(args.sample_text)
    text_alpha_chunks = extract_text_alpha_chunks(args.sample_text, args.alpha)

    with concurrent.futures.ThreadPoolExecutor() as executor:
        futures = []
        for chunk_text, alpha_val, is_silence, sil_duration in text_alpha_chunks:
            if is_silence:
                # Silence chunks are rendered directly as zero samples; no synthesis needed
                silence_samples = int(sil_duration * SAMPLING_RATE)
                silence_audio = np.zeros(silence_samples, dtype=np.int16)
                futures.append(silence_audio)
            else:
                preprocessed_text, _ = preprocessor.preprocess(chunk_text, args.language, args.gender, phone_dictionary)
                preprocessed_text = " ".join(preprocessed_text)
                future = executor.submit(
                    text_synthesis, args.language, args.gender, preprocessed_text,
                    vocoder, model, MAX_WAV_VALUE, device, alpha_val
                )
                futures.append(future)

        # The list mixes ready-made silence arrays and futures; collect results in submission order
        for item in futures:
            if isinstance(item, np.ndarray):
                audio_arr.append(item)
            else:
                audio_arr.append(item.result())

    result_array = np.concatenate(audio_arr, axis=0)
    output_file = args.output_file if args.output_file else f"{args.language}_{args.gender}_output.wav"
    write(output_file, SAMPLING_RATE, result_array)
    print(f"Synthesis completed in {time.time()-start_time:.2f} sec → {output_file}")