"""Tacotron 2 + HiFi-GAN text-to-speech demo, served via Gradio mounted on FastAPI."""

import tempfile

import torch
import torch.nn.functional as F  # BUG FIX: original used F.pad without this import
import torchaudio
import gradio as gr
from fastapi import FastAPI
from speechbrain.pretrained import Tacotron2, HIFIGAN

DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
SR = 22050  # LJSpeech models operate at 22.05 kHz

# Model loading happens once at import time (downloads on first run, cached
# under models/). Both models are moved to DEVICE via run_opts.
tacotron2 = Tacotron2.from_hparams(
    source="speechbrain/tts-tacotron2-ljspeech",
    savedir="models/tacotron2",
    run_opts={"device": DEVICE},
)
hifigan = HIFIGAN.from_hparams(
    source="speechbrain/tts-hifigan-ljspeech",
    savedir="models/hifigan",
    run_opts={"device": DEVICE},
)


@torch.inference_mode()
def synth(text):
    """Synthesize speech for *text*.

    Args:
        text: Input text; None or whitespace-only input is treated as empty.

    Returns:
        Path to a generated WAV file, or None when the input is empty.
    """
    text = (text or "").strip()
    if not text:
        return None

    seq, seq_len = tacotron2.text_to_seq(text)
    seq = [int(x) for x in seq]

    # Tacotron2's encoder convolutions need a minimum number of timesteps,
    # so pad very short token sequences up to min_tokens.
    min_tokens = 5
    pad_id = 0
    if len(seq) < min_tokens:
        seq = seq + [pad_id] * (min_tokens - len(seq))
    seq_len = len(seq)

    seq = torch.tensor(seq, dtype=torch.long, device=DEVICE).unsqueeze(0)
    seq_len = torch.tensor([seq_len], device=DEVICE)
    mel, _, _ = tacotron2.infer(seq, seq_len)

    # Keep mel padding for vocoder safety on very short outputs.
    # BUG FIX: the original referenced F.pad without importing
    # torch.nn.functional as F, so this branch raised NameError.
    if mel.shape[-1] < 5:
        mel = F.pad(mel, (0, 5 - mel.shape[-1]), mode="replicate")

    wav = hifigan.decode_batch(mel)
    if wav.dim() == 3:
        wav = wav.squeeze(1)  # (batch, 1, samples) -> (batch, samples)
    wav = wav[0].cpu()

    # BUG FIX: a fixed shared "out.wav" path lets concurrent requests clobber
    # each other's output; write each result to a unique temp file instead.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    torchaudio.save(out_path, wav.unsqueeze(0), SR)
    return out_path


# NOTE: `io` shadows the stdlib io module; kept as-is since other modules may
# import this name.
io = gr.Interface(
    fn=synth,
    inputs=gr.Textbox(label="Text", lines=3),
    outputs=gr.Audio(type="filepath", label="Output"),
    title="Tacotron 2 + HiFi-GAN",
    allow_flagging="never",
    api_name=False,
)

app = FastAPI()
app = gr.mount_gradio_app(app, io, path="/")