# Chatterbox / app.py
# Author: peterlllmm — "Update app.py" (commit 7f27076, verified)
import nltk

# NLTK >= 3.8.2 loads the "punkt_tab" resource for sent_tokenize, while older
# releases use "punkt". Fetch both (quietly, best-effort) so the app works
# across NLTK versions and with a pre-populated cache / no network.
for _resource in ("punkt", "punkt_tab"):
    try:
        nltk.download(_resource, quiet=True)
    except Exception as exc:
        print(f"NLTK download of {_resource!r} failed: {exc}")

# Standard library
import io
import os
import random

# Third-party
import numpy as np
import torch
import soundfile as sf
from nltk.tokenize import sent_tokenize
from pydub import AudioSegment, silence  # silence: used to trim breaths/gaps
import gradio as gr

# Project
from chatterbox.src.chatterbox.tts import ChatterboxTTS
# ===============================
# DEVICE
# ===============================
# Pick the compute device once at import time; everything below reuses it.
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"
print(f"Running on: {DEVICE}")
# ===============================
# LOAD MODEL ONCE
# ===============================
# Process-wide singleton; populated lazily by get_model().
MODEL = None


def get_model():
    """Return the shared ChatterboxTTS instance, loading it on first use."""
    global MODEL
    if MODEL is not None:
        return MODEL
    print("Loading Chatterbox model...")
    MODEL = ChatterboxTTS.from_pretrained(DEVICE)
    # Some model builds expose .to(); move weights to the target device if so.
    if hasattr(MODEL, "to"):
        MODEL.to(DEVICE)
    print("Model ready.")
    return MODEL


# Warm the model up at import time so the first request is fast.
get_model()
# ===============================
# SEED
# ===============================
def set_seed(seed):
    """Seed every RNG source (random, numpy, torch, CUDA) for reproducible output.

    Args:
        seed: Integer seed applied to all generators.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Query CUDA availability directly instead of relying on the module-level
    # DEVICE global — identical behavior, no hidden coupling.
    if torch.cuda.is_available():
        torch.cuda.manual_seed_all(seed)
# ===============================
# PODCAST SAFE SETTINGS
# ===============================
# Tunables for sentence chunking and for stitching the per-chunk audio.
MAX_CHARS = 220  # max characters per TTS chunk fed to the model
SILENCE_MS = 250 # Reduced slightly since we are cleaning audio
FADE_IN = 10 # Reduced fade to avoid eating words
FADE_OUT = 10 # Reduced fade to avoid weird half-breath sounds
# ===============================
# HELPER: TRIM SILENCE/BREATHS
# ===============================
def trim_audio_segment(audio_segment, silence_thresh=-40):
    """Strip leading/trailing silence (or quiet breath sounds) from a chunk.

    Raise silence_thresh (dBFS) toward 0 if real words are being clipped.
    """
    voiced_ranges = silence.detect_nonsilent(
        audio_segment,
        min_silence_len=100,
        silence_thresh=silence_thresh,
    )

    # Entirely silent (or empty) input: return an empty segment.
    if not voiced_ranges:
        return AudioSegment.empty()

    # Keep everything from the first sound's start to the last sound's end.
    first_start = voiced_ranges[0][0]
    last_end = voiced_ranges[-1][1]
    return audio_segment[first_start:last_end]
# ===============================
# MAIN TTS FUNCTION
# ===============================
def _chunk_text(text, max_chars=MAX_CHARS):
    """Split *text* into sentence-aligned chunks of roughly max_chars each.

    Sentences are packed greedily; a chunk is flushed when adding the next
    sentence would reach max_chars. Empty accumulators are never emitted
    (fixes a bug where a long first sentence produced a leading "" chunk).
    """
    chunks = []
    current = ""
    for sentence in sent_tokenize(text):
        if len(current) + len(sentence) < max_chars:
            current += " " + sentence
        else:
            if current.strip():
                chunks.append(current.strip())
            current = sentence
    if current.strip():
        chunks.append(current.strip())
    return chunks


def generate_tts(
    text,
    ref_audio=None,
    exaggeration=0.4,
    temperature=0.7,
    seed=0,
    cfg_weight=0.6,
):
    """Synthesize *text* to an MP3 file and return its path.

    Args:
        text: Story text; split into sentence chunks before generation.
        ref_audio: Optional path to a reference voice clip for cloning.
        exaggeration: Emotion intensity passed to the model.
        temperature: Sampling temperature passed to the model.
        seed: RNG seed; 0 means "leave the generators unseeded" (random).
        cfg_weight: Classifier-free-guidance weight (voice stability).

    Returns:
        Path to the exported "story_voice.mp3".

    Raises:
        ValueError: If the text yields no speakable chunks.
    """
    model = get_model()
    if seed != 0:
        set_seed(int(seed))

    kwargs = {
        "exaggeration": exaggeration,
        "temperature": temperature,
        "cfg_weight": cfg_weight,
    }

    # --------------------------------
    # Handle reference voice
    # --------------------------------
    temp_prompt = None
    if ref_audio:
        try:
            audio = AudioSegment.from_file(ref_audio)
            temp_prompt = "voice_prompt.wav"
            audio.export(temp_prompt, format="wav")
            kwargs["audio_prompt_path"] = temp_prompt
        except Exception:
            # Deliberate best-effort: a bad reference clip falls back to the
            # default voice instead of failing the whole request.
            print("Reference audio failed — using default voice.")

    try:
        # --------------------------------
        # Sentence chunking
        # --------------------------------
        chunks = _chunk_text(text)
        print(f"Total chunks: {len(chunks)}")
        if not chunks:
            raise ValueError("No text to synthesize — please enter a story.")

        # --------------------------------
        # Generate audio per chunk
        # --------------------------------
        final_audio = AudioSegment.empty()
        clean_pause = AudioSegment.silent(duration=SILENCE_MS)

        for i, chunk in enumerate(chunks):
            print(f"Generating chunk {i+1}/{len(chunks)}")

            # 1. Generate raw audio and round-trip it through an in-memory
            #    WAV buffer to get a pydub segment.
            wav = model.generate(chunk, **kwargs)
            wav_np = wav.squeeze(0).cpu().numpy()
            buffer = io.BytesIO()
            sf.write(buffer, wav_np, model.sr, format="WAV")
            buffer.seek(0)
            segment = AudioSegment.from_wav(buffer)

            # 2. Strip the model's trailing breath/silence BEFORE adding our
            #    own clean pause.
            segment = trim_audio_segment(segment, silence_thresh=-45)

            # 3. Light fades only after trimming, and only if audio remains.
            if len(segment) > 0:
                segment = segment.fade_in(FADE_IN).fade_out(FADE_OUT)
            final_audio += segment + clean_pause

        # --------------------------------
        # Export
        # --------------------------------
        output_path = "story_voice.mp3"
        final_audio.export(output_path, format="mp3", bitrate="192k")
        return output_path
    finally:
        # Always remove the temporary voice prompt, even if generation failed
        # (previously it leaked on any exception).
        if temp_prompt and os.path.exists(temp_prompt):
            os.remove(temp_prompt)
# ===============================
# GRADIO UI
# ===============================
with gr.Blocks() as demo:
    gr.Markdown("## 🎙️ Storyteller / Podcast Chatterbox TTS (Cleaned)")

    # Inputs
    text = gr.Textbox(
        label="Story Text",
        lines=12,
        placeholder="Paste your full story here...",
    )
    ref = gr.Audio(
        sources=["upload", "microphone"],
        type="filepath",
        label="Reference Voice (optional)",
    )
    exaggeration = gr.Slider(0.25, 1.0, value=0.4, step=0.05, label="Emotion")
    temperature = gr.Slider(0.3, 1.2, value=0.7, step=0.05, label="Variation")
    cfg = gr.Slider(0.3, 1.0, value=0.6, step=0.05, label="Voice Stability")
    seed = gr.Number(value=0, label="Seed (0 = random)")

    # Action + output
    btn = gr.Button("Generate Voice")
    out = gr.Audio(label="Final Audio")

    # NOTE: input order must match generate_tts's positional parameters
    # (text, ref_audio, exaggeration, temperature, seed, cfg_weight).
    btn.click(
        fn=generate_tts,
        inputs=[text, ref, exaggeration, temperature, seed, cfg],
        outputs=out,
    )

demo.launch(share=True)