# Shroukkkk's picture
# Update app.py
# e030059 verified
import torch
import torchaudio
import gradio as gr
from fastapi import FastAPI
from speechbrain.pretrained import Tacotron2, HIFIGAN
# Prefer the GPU when PyTorch can see one; otherwise run everything on CPU.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Sample rate in Hz.
# NOTE(review): presumably matches the output rate of the LJSpeech models
# loaded below -- confirm against the model cards.
SR = 22050

# Acoustic model, loaded once at import time and placed on DEVICE.
tacotron2 = Tacotron2.from_hparams(
    run_opts={"device": DEVICE},
    savedir="models/tacotron2",
    source="speechbrain/tts-tacotron2-ljspeech",
)

# Vocoder, loaded once at import time and placed on DEVICE.
hifigan = HIFIGAN.from_hparams(
    run_opts={"device": DEVICE},
    savedir="models/hifigan",
    source="speechbrain/tts-hifigan-ljspeech",
)
@torch.inference_mode()
def synth(text):
    """Synthesize speech for *text* and return the path of the WAV file.

    The text is converted to a token sequence, run through the module-level
    ``tacotron2`` model to obtain a mel-spectrogram, vocoded with the
    module-level ``hifigan`` model, and written to ``out.wav`` in the
    current working directory at ``SR`` Hz.

    Args:
        text: Text to speak. ``None`` is treated like an empty string.

    Returns:
        The string ``"out.wav"`` after the file has been written, or
        ``None`` when *text* is ``None``, empty, or whitespace-only.
    """
    text = (text or "").strip()
    if not text:
        return None

    # Minimum number of timesteps fed to the encoder and to the vocoder.
    min_steps = 5
    # NOTE(review): assumes token id 0 is safe to use as padding -- confirm
    # against the Tacotron2 symbol table.
    pad_id = 0

    token_ids, _ = tacotron2.text_to_seq(text)
    token_ids = [int(tok) for tok in token_ids]

    # Tacotron2 encoder conv needs enough timesteps
    shortfall = min_steps - len(token_ids)
    if shortfall > 0:
        token_ids = token_ids + [pad_id] * shortfall

    # Always derive the length from the final sequence so the length tensor
    # cannot disagree with the (possibly padded) input tensor.
    seq = torch.tensor(token_ids, dtype=torch.long, device=DEVICE).unsqueeze(0)
    seq_len = torch.tensor([len(token_ids)], device=DEVICE)

    mel, _, _ = tacotron2.infer(seq, seq_len)

    # Optional: still keep mel padding for vocoder safety.
    # The original called ``F.pad`` without ``F`` ever being imported, which
    # raised NameError on short outputs; use the fully qualified function.
    missing_frames = min_steps - mel.shape[-1]
    if missing_frames > 0:
        mel = torch.nn.functional.pad(
            mel, (0, missing_frames), mode="replicate"
        )

    wav = hifigan.decode_batch(mel)
    if wav.dim() == 3:
        wav = wav.squeeze(1)
    wav = wav[0].cpu()

    # NOTE(review): every call writes the same relative path, so overlapping
    # requests would overwrite each other's output.
    out_path = "out.wav"
    torchaudio.save(out_path, wav.unsqueeze(0), SR)
    return out_path
# Single-function UI: a three-line text box in, an audio file path out.
io = gr.Interface(
    title="Tacotron 2 + HiFi-GAN",
    fn=synth,
    inputs=gr.Textbox(lines=3, label="Text"),
    outputs=gr.Audio(label="Output", type="filepath"),
    api_name=False,
    allow_flagging="never",
)

# Mount the Gradio UI at the root path of a fresh FastAPI application.
app = gr.mount_gradio_app(FastAPI(), io, path="/")