# TTS / app.py
# Uploaded by aberbossio, commit 12407c8 ("Upload 3 files", verified).
import gc
import re
import threading
import traceback
import gradio as gr
import numpy as np
import torch
from transformers import pipeline
# Model repository handed to the transformers text-to-audio pipeline.
MODEL_ID: str = "fishaudio/s2-pro"
# Fallback sample rate (Hz) used when the pipeline output reports none.
DEFAULT_SR: int = 24000
# Silence (milliseconds) inserted between consecutive stitched segments.
SILENCE_MS: int = 180
# Default per-chunk character budget used by chunk_text.
CHUNK_CHARS: int = 280

# Lazily-built shared pipeline, plus the cached message of a failed load.
_pipe = None
_pipe_error: "str | None" = None
# Guards one-time construction of _pipe.
_pipe_lock = threading.Lock()
# Held for the whole multi-chunk generation of one request.
_gen_lock = threading.Lock()
def load_pipeline():
    """Return the shared text-to-audio pipeline, building it on first use.

    The pipeline for ``MODEL_ID`` is created once (on CPU, ``device=-1``,
    with ``trust_remote_code=True``) and cached in the module global
    ``_pipe``. If construction fails, the failure message is cached in
    ``_pipe_error`` and every later call re-raises it without retrying.

    Returns:
        The cached pipeline object.

    Raises:
        RuntimeError: if loading failed on this call or on an earlier one.
    """
    global _pipe, _pipe_error
    # Unlocked fast path: only contend for the lock while nothing is decided.
    if _pipe is None and _pipe_error is None:
        with _pipe_lock:
            # Re-check under the lock: another thread may have finished
            # (or failed) loading while this one was waiting.
            if _pipe is None and _pipe_error is None:
                try:
                    _pipe = pipeline(
                        task="text-to-audio",
                        model=MODEL_ID,
                        device=-1,
                        trust_remote_code=True,
                    )
                except Exception as e:
                    _pipe_error = f"Failed to load {MODEL_ID}: {e}"
                    raise RuntimeError(_pipe_error) from e
    # A recorded failure (with no pipeline) is sticky for the process lifetime.
    if _pipe is None and _pipe_error is not None:
        raise RuntimeError(_pipe_error)
    return _pipe
def normalize_audio(audio):
    """Coerce model audio to a float32 array and peak-limit it to [-1, 1].

    Multi-dimensional input has its singleton axes squeezed away. A
    true multi-channel array (no singleton axis) therefore stays
    multi-dimensional; callers that concatenate with 1-D buffers should
    keep that in mind. Samples are divided by the peak magnitude only when
    that peak exceeds 1.0, so quieter audio is returned unscaled.

    Args:
        audio: Anything ``np.asarray`` accepts (list, ndarray, CPU tensor).

    Returns:
        A float32 ``np.ndarray``.
    """
    samples = np.asarray(audio, dtype=np.float32)
    if samples.ndim > 1:
        samples = samples.squeeze()
    # Nothing to scale in an empty buffer.
    if samples.size == 0:
        return samples
    peak = np.max(np.abs(samples))
    return samples / peak if peak > 1.0 else samples
def split_long_sentence(sentence: str, limit: int):
    """Break one over-long sentence into chunks of at most ``limit`` chars.

    Words (whitespace-separated tokens) are greedily packed, joined by
    single spaces, while the running chunk stays within ``limit``. A single
    token longer than ``limit`` is hard-split into consecutive
    ``limit``-sized slices. Examples are long URLs, or scripts such as
    CJK that do not separate words with spaces. Without this split, such a
    token would come back as one chunk of unbounded length.

    Args:
        sentence: Text to split.
        limit: Maximum characters per chunk; values below 1 are treated as 1.

    Returns:
        A list of non-empty chunk strings, or ``[]`` if ``sentence`` has no words.
    """
    width = max(int(limit), 1)
    pieces = []
    for word in sentence.split():
        if len(word) <= width:
            pieces.append(word)
        else:
            # Hard-split oversized tokens so no chunk can exceed the budget.
            pieces.extend(word[i:i + width] for i in range(0, len(word), width))
    if not pieces:
        return []
    chunks = []
    current = pieces[0]
    for piece in pieces[1:]:
        # Only full-width slices precede the tail of a hard-split token, so
        # slices of the same token are never re-joined with a space.
        if len(current) + 1 + len(piece) <= width:
            current += " " + piece
        else:
            chunks.append(current)
            current = piece
    chunks.append(current)
    return chunks
def chunk_text(text: str, limit: int = CHUNK_CHARS):
    """Split text into space-joined chunks sized for one model call each.

    Whitespace runs are collapsed to single spaces. The text is split into
    sentences after ``.``, ``!``, ``?`` or the Devanagari danda ``।`` when
    followed by whitespace. Sentences longer than ``limit`` are first broken
    up with ``split_long_sentence``. The resulting parts are then greedily
    packed, joined by single spaces, while the combined length stays within
    ``limit``.

    Args:
        text: Input text; ``None`` is treated as empty.
        limit: Character budget per chunk.

    Returns:
        A list of chunk strings (``[]`` for blank input).
    """
    cleaned = re.sub(r"\s+", " ", (text or "").strip())
    if not cleaned:
        return []

    # Stage 1: sentences, with over-long ones broken into smaller parts.
    parts = []
    for raw in re.split(r"(?<=[.!?।])\s+", cleaned):
        sentence = raw.strip()
        if not sentence:
            continue
        if len(sentence) > limit:
            parts.extend(split_long_sentence(sentence, limit))
        else:
            parts.append(sentence)

    # Stage 2: greedy packing of parts into chunks.
    chunks = []
    buffer = ""
    for part in parts:
        if buffer and len(buffer) + 1 + len(part) <= limit:
            buffer = buffer + " " + part
        else:
            if buffer:
                chunks.append(buffer)
            buffer = part
    if buffer:
        chunks.append(buffer)
    return chunks
def run_one_chunk(pipe, text_chunk: str):
    """Synthesize one text chunk and return ``(sample_rate, samples)``.

    Two output shapes from ``pipe`` are accepted:

    * a dict with an ``"audio"`` entry and a ``"sampling_rate"`` or
      ``"sample_rate"`` entry (``DEFAULT_SR`` when neither is truthy);
    * a 2-tuple ``(sample_rate, audio)``.

    Returns:
        ``(int sample_rate, float32 samples)``, with the samples passed
        through ``normalize_audio``.

    Raises:
        gr.Error: for any other output type, or when no audio is present.
    """
    output = pipe(text_chunk)
    if isinstance(output, dict):
        samples = output.get("audio")
        rate = output.get("sampling_rate") or output.get("sample_rate") or DEFAULT_SR
    elif isinstance(output, tuple) and len(output) == 2:
        rate, samples = output
    else:
        raise gr.Error(f"Unexpected model output type: {type(output)}")
    if samples is None:
        raise gr.Error("Model returned no audio.")
    return int(rate), normalize_audio(samples)
def synthesize_long(text: str):
    """Synthesize speech for arbitrarily long text by chunking and stitching.

    The text is split with ``chunk_text``, and each chunk is synthesized with
    ``run_one_chunk``. The per-chunk audio is then concatenated, with
    ``SILENCE_MS`` milliseconds of silence between consecutive chunks.

    Args:
        text: Input text; surrounding whitespace is ignored.

    Returns:
        ``((sample_rate, samples), info)``: the audio tuple for
        ``gr.Audio(type="numpy")`` and a human-readable status string.

    Raises:
        gr.Error: on empty input, model load failure, generation failure,
            a sample-rate change between segments, or when no audio was
            produced.
    """
    text = (text or "").strip()
    if not text:
        raise gr.Error("Please enter some text.")
    chunks = chunk_text(text)
    if not chunks:
        raise gr.Error("Could not split input text.")
    try:
        pipe = load_pipeline()
    except RuntimeError as e:
        # load_pipeline reports failures as RuntimeError; convert it so the
        # message is shown to the user instead of a generic error.
        raise gr.Error(str(e)) from e
    silence = None
    pieces = []
    sr = DEFAULT_SR
    # Held for the whole request so chunks from concurrent requests do not
    # interleave on the shared pipeline.
    with _gen_lock:
        try:
            for idx, chunk in enumerate(chunks, start=1):
                chunk_sr, audio = run_one_chunk(pipe, chunk)
                if silence is None:
                    # The first segment fixes the output rate and gap length.
                    sr = chunk_sr
                    silence = np.zeros(int(sr * SILENCE_MS / 1000), dtype=np.float32)
                elif chunk_sr != sr:
                    # Concatenating mismatched rates would distort playback.
                    raise gr.Error(
                        f"Sample rate changed between segments ({sr} Hz vs {chunk_sr} Hz)."
                    )
                pieces.append(audio)
                if idx < len(chunks):
                    pieces.append(silence)
        except gr.Error:
            # Already user-facing (e.g. "Model returned no audio."); do not re-wrap.
            raise
        except Exception as e:
            tb = traceback.format_exc(limit=2)
            raise gr.Error(f"Generation failed: {e}\n\n{tb}") from e
        finally:
            gc.collect()
            if torch.cuda.is_available():
                torch.cuda.empty_cache()
    if not pieces:
        raise gr.Error("No audio was generated.")
    final_audio = np.concatenate(pieces)
    info = (
        f"Done. Model: {MODEL_ID} | Segments: {len(chunks)} | "
        f"Input characters: {len(text)} | Output seconds: {len(final_audio) / sr:.1f}"
    )
    return (sr, final_audio), info
def app_info():
    """Return the explanatory Markdown text shown under the page title."""
    sentences = (
        "Long text is supported by auto-splitting your input into smaller chunks and stitching the audio together.",
        "There is no small textbox cap or single-pass text cap in the app itself, but the machine and model still have practical limits.",
    )
    return " ".join(sentences)
# Gradio UI: title and help text, a text input, a generate button, the
# stitched audio output, and a read-only status line.
with gr.Blocks() as demo:
    gr.Markdown("# Fish Audio S2 Pro Text to Speech")
    gr.Markdown(app_info())
    text = gr.Textbox(
        label="Text",
        lines=14,
        placeholder="Type very long text here. The app will split it into chunks automatically.",
    )
    btn = gr.Button("Generate Speech")
    # type="numpy" matches the (sample_rate, samples) tuple that
    # synthesize_long returns as its first output.
    audio = gr.Audio(label="Audio", type="numpy", show_download_button=True)
    # Receives the info string that synthesize_long returns as its second output.
    status = gr.Textbox(label="Status", interactive=False)
    # api_name="tts" names this event's endpoint (presumably for the
    # Gradio client API; confirm against the deployed Gradio version).
    btn.click(synthesize_long, inputs=text, outputs=[audio, status], api_name="tts")
# Launch the app only when this file is executed directly.
if __name__ == "__main__":
    demo.launch()