# Hosting-page residue (not part of the program), kept as comments:
# arijitx's picture
# Fix model path: arijitx/chatterbox-bangla -> BosonLab/chatterbox-bangla
# c684caa verified
"""
ChatterBox Bengali TTS Space
Fine-tuned Bengali TTS model for text-to-speech synthesis
"""
import sys
import os
import tempfile
import gradio as gr
import torch
from huggingface_hub import snapshot_download
import torchaudio
# Put the "chatterbox-finetuning" directory that sits next to this file at the
# front of sys.path; this must happen before the `src.chatterbox_` import
# below (presumably that package lives inside it).
_app_dir = os.path.dirname(__file__)
chatterbox_finetuning_path = os.path.join(_app_dir, "chatterbox-finetuning")
sys.path.insert(0, chatterbox_finetuning_path)

from src.chatterbox_.tts import ChatterboxTTS
# Model configuration
MODEL_ID = "BosonLab/chatterbox-bangla"
if torch.cuda.is_available():
    DEVICE = "cuda"
else:
    DEVICE = "cpu"

# Download the checkpoint and build the model once, at import time, so the
# request handler can reuse the module-level `model`.
print(f"Loading model from {MODEL_ID}...")
print(f"Using device: {DEVICE}")
try:
    # Fetch the model repository snapshot from Hugging Face
    model_dir = snapshot_download(MODEL_ID)
    print(f"Model downloaded to: {model_dir}")

    # Instantiate ChatterBox TTS from the downloaded directory
    model = ChatterboxTTS.from_local(model_dir, device=DEVICE)
    print("Model loaded successfully!")
except Exception as exc:
    # Log the failure, then re-raise: the app cannot run without a model
    print(f"Error loading model: {exc}")
    raise
def generate_speech(text, reference_audio=None):
    """
    Generate speech from text using the fine-tuned model

    Args:
        text: Bengali text to convert to speech
        reference_audio: Path to reference audio file for voice cloning,
            or None to generate without an audio prompt

    Returns:
        Tuple of (audio, metadata). On success, audio is a
        (model.sr, samples) tuple where samples is the generated waveform
        as a NumPy array, and metadata is a short status string. On empty
        input or any failure, audio is None and metadata starts with
        "Error: ".
    """
    # Reject empty / whitespace-only input up front instead of handing it
    # to the model; use the same (None, "Error: ...") shape as other failures.
    if not text or not text.strip():
        return None, "Error: Please enter some Bengali text"

    try:
        # Use voice cloning only when a reference audio path was supplied
        if reference_audio is not None:
            wav = model.generate(text, audio_prompt_path=reference_audio)
        else:
            wav = model.generate(text)

        # .numpy() raises for tensors that live on an accelerator or that
        # require grad, so detach and move to host memory first.
        samples = wav.squeeze(0).detach().cpu().numpy()
        return (model.sr, samples), f"Generated {len(text)} characters of Bengali text"
    except Exception as e:
        # UI boundary: report the failure in the metadata box rather than
        # letting the exception propagate into the Gradio event handler.
        print(f"Error generating speech: {e}")
        return None, f"Error: {str(e)}"
# Build the Gradio UI: inputs in the left column, results in the right one
with gr.Blocks(title="ChatterBox Bengali TTS", theme=gr.themes.Soft()) as demo:
    with gr.Row():
        # Left column: text to speak, optional reference voice, trigger button
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Bengali Text",
                placeholder="Enter Bengali text here...",
                lines=5,
                max_lines=10,
            )
            reference_audio = gr.Audio(
                label="Reference Audio (Optional for Voice Cloning)",
                type="filepath",
            )
            generate_btn = gr.Button("Generate Speech", variant="primary")

        # Right column: generated audio and the status/metadata message
        with gr.Column(scale=1):
            audio_output = gr.Audio(label="Generated Speech")
            metadata_output = gr.Textbox(label="Metadata", lines=3)

    # Wire the button to generate_speech: (text, reference) -> (audio, metadata)
    generate_btn.click(
        fn=generate_speech,
        inputs=[text_input, reference_audio],
        outputs=[audio_output, metadata_output],
    )
if __name__ == "__main__":
    # Listen on all interfaces on port 7860, without a public share link
    launch_options = {
        "server_name": "0.0.0.0",
        "server_port": 7860,
        "share": False,
    }
    demo.launch(**launch_options)