Spaces:
Runtime error
Runtime error
| import os | |
| import subprocess | |
| import logging | |
| import torch | |
| import gradio as gr | |
| from scipy.io.wavfile import write | |
| # 自定义模块 | |
| import commons | |
| import utils | |
| from data_utils import TextAudioLoader, TextAudioCollate, TextAudioSpeakerLoader, TextAudioSpeakerCollate | |
| from models import SynthesizerTrn | |
| from text.symbols import symbols | |
| from text import text_to_sequence | |
# Logging setup: INFO-level root configuration plus a logger named after this module.
logging.basicConfig(level="INFO")
logger = logging.getLogger(name=__name__)
| # 编译 monotonic_align 模块 | |
| def compile_monotonic_align(): | |
| try: | |
| os.system('cd monotonic_align && python setup.py build_ext --inplace && cd ..') | |
| logger.info("Successfully compiled monotonic_align.") | |
| except subprocess.CalledProcessError as e: | |
| logger.error(f"Failed to compile monotonic_align: {e}") | |
| raise RuntimeError("Compilation of monotonic_align failed.") | |
# Load the hyperparameter config and the synthesizer model.
def load_config_and_model(config_path, checkpoint_path):
    """Load hyperparameters and build a ``SynthesizerTrn`` with checkpoint weights.

    Args:
        config_path: Path to the hyperparameter file read by
            ``utils.get_hparams_from_file``.
        checkpoint_path: Path to the generator checkpoint.

    Returns:
        ``(hps, net_g)``: the parsed hyperparameters and the model, switched to
        eval mode and loaded from *checkpoint_path*.

    Raises:
        FileNotFoundError: If either path does not exist. The config path is
            checked first.
    """
    required_files = (("Config", config_path), ("Checkpoint", checkpoint_path))
    for label, path in required_files:
        if not os.path.exists(path):
            raise FileNotFoundError(f"{label} file not found: {path}")

    hps = utils.get_hparams_from_file(config_path)
    logger.info("Loaded hyperparameters from config file.")

    # Presumably the STFT frequency-bin count (filter_length // 2 + 1) and the
    # training segment length in frames — confirm against SynthesizerTrn.
    spec_bins = hps.data.filter_length // 2 + 1
    segment_frames = hps.train.segment_size // hps.data.hop_length
    net_g = SynthesizerTrn(len(symbols), spec_bins, segment_frames, **hps.model)
    net_g.eval()
    logger.info("Initialized SynthesizerTrn model.")

    # NOTE(review): the trailing None presumably fills an optimizer slot in
    # utils.load_checkpoint — confirm.
    utils.load_checkpoint(checkpoint_path, net_g, None)
    logger.info(f"Loaded model checkpoint from {checkpoint_path}.")
    return hps, net_g
# Text-to-speech synthesis.
def text_to_speech(
    content,
    hps,
    net_g,
    *,
    noise_scale=0.667,
    noise_scale_w=0.8,
    length_scale=1,
):
    """Synthesize speech for *content* with the loaded generator.

    Args:
        content: Input text. Must be a non-blank ``str``.
        hps: Hyperparameters. Reads ``hps.data.text_cleaners``,
            ``hps.data.add_blank`` and ``hps.data.sampling_rate``.
        net_g: Model exposing ``infer(x, x_lengths, noise_scale=...,
            noise_scale_w=..., length_scale=...)``.
        noise_scale, noise_scale_w, length_scale: Forwarded unchanged to
            ``net_g.infer``. The defaults match the values previously
            hard-coded here.

    Returns:
        ``(sampling_rate, audio)``, where ``audio`` is the float NumPy array
        taken from ``infer(...)[0][0, 0]``.

    Raises:
        ValueError: If *content* is not a string, or is empty/whitespace-only.
        RuntimeError: If any step of synthesis fails. The original exception
            is chained as ``__cause__``.
    """
    # Type check first so non-str inputs never hit an ambiguous truth test;
    # whitespace-only text counts as empty.
    if not isinstance(content, str) or not content.strip():
        raise ValueError("Input text is empty or invalid.")
    try:
        # Convert text to a sequence of symbol IDs using the configured cleaners.
        seq = text_to_sequence(content, hps.data.text_cleaners)
        if hps.data.add_blank:
            # Insert 0 between every pair of IDs, as the config requests.
            seq = commons.intersperse(seq, 0)
        stn_tst = torch.LongTensor(seq)
        with torch.no_grad():
            x_tst = stn_tst.unsqueeze(0)  # add a leading batch dimension of 1
            x_tst_lengths = torch.LongTensor([stn_tst.size(0)])
            audio = (
                net_g.infer(
                    x_tst,
                    x_tst_lengths,
                    noise_scale=noise_scale,
                    noise_scale_w=noise_scale_w,
                    length_scale=length_scale,
                )[0][0, 0]
                .float()
                .cpu()  # .numpy() fails on non-CPU tensors
                .numpy()
            )
        return hps.data.sampling_rate, audio
    except Exception as e:
        # logger.exception keeps the traceback; `from e` keeps the cause chain.
        logger.exception(f"Error during text-to-speech synthesis: {e}")
        raise RuntimeError("Failed to generate audio.") from e
# Gradio user interface.
def create_gradio_interface(hps, net_g):
    """Build (but do not launch) the Gradio Blocks app around ``text_to_speech``.

    The app has one "Basic" tab containing a text box, a "Convert" button and
    an audio output. Clicking the button synthesizes the text. Any synthesis
    error is logged and produces ``None`` instead of propagating.

    Args:
        hps: Hyperparameters forwarded to ``text_to_speech``.
        net_g: Model forwarded to ``text_to_speech``.

    Returns:
        The ``gr.Blocks`` app.
    """

    # Name kept as-is: Gradio may derive the endpoint name from the function.
    def safe_syn(content):
        try:
            result = text_to_speech(content, hps, net_g)
        except Exception as err:
            logger.error(f"Error in Gradio interface: {err}")
            result = None
        return result

    with gr.Blocks() as app:
        with gr.Tabs():
            with gr.TabItem("Basic"):
                text_box = gr.Textbox(label="Input Text", placeholder="Enter text here...")
                convert_button = gr.Button("Convert", variant="primary")
                audio_player = gr.Audio(label="Output Audio")
                convert_button.click(safe_syn, inputs=text_box, outputs=audio_player)
    return app
# Program entry point.
def main(config_path="configs/steins_gate_base.json", checkpoint_path="G_265000.pth"):
    """Compile the extension, load the model and launch the Gradio app.

    Args:
        config_path: Hyperparameter file. Defaults to
            ``configs/steins_gate_base.json``.
        checkpoint_path: Generator checkpoint. Defaults to ``G_265000.pth``.

    Any exception during startup or serving is logged at CRITICAL level with
    its traceback, and the process exits with status 1.
    """
    try:
        compile_monotonic_align()
        hps, net_g = load_config_and_model(config_path, checkpoint_path)
        app = create_gradio_interface(hps, net_g)
        logger.info("Starting Gradio interface...")
        app.launch()
    except Exception as e:
        # exc_info=True keeps the traceback that the bare message used to drop.
        logger.critical(f"Fatal error: {e}", exc_info=True)
        # The ``exit`` builtin is injected by ``site`` and can be absent
        # (``python -S``, frozen apps). SystemExit is always available.
        raise SystemExit(1) from e


if __name__ == "__main__":
    main()