Spaces:
Sleeping
Sleeping
Create app.py
Browse files
app.py
ADDED
|
@@ -0,0 +1,218 @@
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 1 |
+
import random
|
| 2 |
+
import numpy as np
|
| 3 |
+
import torch
|
| 4 |
+
import gradio as gr
|
| 5 |
+
import logging
|
| 6 |
+
from pathlib import Path
|
| 7 |
+
import sys
|
| 8 |
+
import re
|
| 9 |
+
from typing import List
|
| 10 |
+
|
| 11 |
+
logging.basicConfig(level=logging.INFO)
|
| 12 |
+
logger = logging.getLogger(__name__)
|
| 13 |
+
|
| 14 |
+
# εΌ·εΆ torch.load δ½Ώη¨ CPU
|
| 15 |
+
original_torch_load = torch.load
|
| 16 |
+
def patched_torch_load(f, map_location=None, **kwargs):
|
| 17 |
+
if map_location is None:
|
| 18 |
+
map_location = 'cpu'
|
| 19 |
+
logger.info(f"π§ Loading with map_location={map_location}")
|
| 20 |
+
return original_torch_load(f, map_location=map_location, **kwargs)
|
| 21 |
+
torch.load = patched_torch_load
|
| 22 |
+
if 'torch' in sys.modules:
|
| 23 |
+
sys.modules['torch'].load = patched_torch_load
|
| 24 |
+
logger.info("β
Applied torch.load device mapping patch")
|
| 25 |
+
|
| 26 |
+
DEVICE = "cpu"
|
| 27 |
+
logger.info("π Running on CPU")
|
| 28 |
+
|
| 29 |
+
MODEL = None
|
| 30 |
+
def get_or_load_model():
|
| 31 |
+
global MODEL, DEVICE
|
| 32 |
+
if MODEL is None:
|
| 33 |
+
print("Model not loaded, initializing...")
|
| 34 |
+
try:
|
| 35 |
+
try:
|
| 36 |
+
from chatterbox.src.chatterbox.tts import ChatterboxTTS
|
| 37 |
+
logger.info("β
Using official chatterbox.src import path")
|
| 38 |
+
except ImportError:
|
| 39 |
+
from chatterbox import ChatterboxTTS
|
| 40 |
+
logger.info("β
Using chatterbox direct import path")
|
| 41 |
+
MODEL = ChatterboxTTS.from_pretrained("cpu")
|
| 42 |
+
MODEL.device = "cpu"
|
| 43 |
+
logger.info(f"β
Model loaded successfully on {DEVICE}")
|
| 44 |
+
except Exception as e:
|
| 45 |
+
logger.error(f"β Error loading model: {e}")
|
| 46 |
+
raise
|
| 47 |
+
return MODEL
|
| 48 |
+
|
| 49 |
+
def set_seed(seed: int):
|
| 50 |
+
torch.manual_seed(seed)
|
| 51 |
+
random.seed(seed)
|
| 52 |
+
np.random.seed(seed)
|
| 53 |
+
|
| 54 |
+
def split_text_into_chunks(text: str, max_chars: int = 250) -> List[str]:
|
| 55 |
+
if len(text) <= max_chars:
|
| 56 |
+
return [text]
|
| 57 |
+
sentences = re.split(r'(?<=[.!?])\s+', text)
|
| 58 |
+
chunks = []
|
| 59 |
+
current_chunk = ""
|
| 60 |
+
for sentence in sentences:
|
| 61 |
+
if len(sentence) > max_chars:
|
| 62 |
+
if current_chunk:
|
| 63 |
+
chunks.append(current_chunk.strip())
|
| 64 |
+
current_chunk = ""
|
| 65 |
+
parts = re.split(r'(?<=,)\s+', sentence)
|
| 66 |
+
for part in parts:
|
| 67 |
+
if len(part) > max_chars:
|
| 68 |
+
words = part.split()
|
| 69 |
+
word_chunk = ""
|
| 70 |
+
for word in words:
|
| 71 |
+
if len(word_chunk + " " + word) <= max_chars:
|
| 72 |
+
word_chunk += " " + word if word_chunk else word
|
| 73 |
+
else:
|
| 74 |
+
if word_chunk:
|
| 75 |
+
chunks.append(word_chunk.strip())
|
| 76 |
+
word_chunk = word
|
| 77 |
+
if word_chunk:
|
| 78 |
+
chunks.append(word_chunk.strip())
|
| 79 |
+
else:
|
| 80 |
+
if len(current_chunk + " " + part) <= max_chars:
|
| 81 |
+
current_chunk += " " + part if current_chunk else part
|
| 82 |
+
else:
|
| 83 |
+
if current_chunk:
|
| 84 |
+
chunks.append(current_chunk.strip())
|
| 85 |
+
current_chunk = part
|
| 86 |
+
else:
|
| 87 |
+
if len(current_chunk + " " + sentence) <= max_chars:
|
| 88 |
+
current_chunk += " " + sentence if current_chunk else sentence
|
| 89 |
+
else:
|
| 90 |
+
if current_chunk:
|
| 91 |
+
chunks.append(current_chunk.strip())
|
| 92 |
+
current_chunk = sentence
|
| 93 |
+
if current_chunk:
|
| 94 |
+
chunks.append(current_chunk.strip())
|
| 95 |
+
return [chunk for chunk in chunks if chunk.strip()]
|
| 96 |
+
|
| 97 |
+
def generate_tts_audio(
|
| 98 |
+
text_input: str,
|
| 99 |
+
audio_prompt_path_input: str,
|
| 100 |
+
exaggeration_input: float,
|
| 101 |
+
temperature_input: float,
|
| 102 |
+
seed_num_input: int,
|
| 103 |
+
cfgw_input: float,
|
| 104 |
+
chunk_size: int = 250
|
| 105 |
+
) -> tuple[int, np.ndarray]:
|
| 106 |
+
try:
|
| 107 |
+
current_model = get_or_load_model()
|
| 108 |
+
if current_model is None:
|
| 109 |
+
raise RuntimeError("TTS model is not loaded.")
|
| 110 |
+
if seed_num_input != 0:
|
| 111 |
+
set_seed(int(seed_num_input))
|
| 112 |
+
text_chunks = split_text_into_chunks(text_input, chunk_size)
|
| 113 |
+
logger.info(f"Processing {len(text_chunks)} text chunk(s)")
|
| 114 |
+
generated_wavs = []
|
| 115 |
+
for i, chunk in enumerate(text_chunks):
|
| 116 |
+
logger.info(f"Generating chunk {i+1}/{len(text_chunks)}: '{chunk[:50]}...'")
|
| 117 |
+
wav = current_model.generate(
|
| 118 |
+
chunk,
|
| 119 |
+
audio_prompt_path=audio_prompt_path_input,
|
| 120 |
+
exaggeration=exaggeration_input,
|
| 121 |
+
temperature=temperature_input,
|
| 122 |
+
cfg_weight=cfgw_input,
|
| 123 |
+
)
|
| 124 |
+
generated_wavs.append(wav)
|
| 125 |
+
if len(generated_wavs) > 1:
|
| 126 |
+
silence_samples = int(0.3 * current_model.sr)
|
| 127 |
+
silence = torch.zeros(1, silence_samples, dtype=generated_wavs[0].dtype)
|
| 128 |
+
final_wav = generated_wavs[0]
|
| 129 |
+
for wav_chunk in generated_wavs[1:]:
|
| 130 |
+
final_wav = torch.cat([final_wav, silence, wav_chunk], dim=1)
|
| 131 |
+
else:
|
| 132 |
+
final_wav = generated_wavs[0]
|
| 133 |
+
return (current_model.sr, final_wav.squeeze(0).numpy())
|
| 134 |
+
except Exception as e:
|
| 135 |
+
logger.error(f"β Generation failed: {e}")
|
| 136 |
+
raise gr.Error(f"Generation failed: {str(e)}")
|
| 137 |
+
|
| 138 |
+
with gr.Blocks(title="ποΈ Chatterbox-TTS (CPU)", theme=gr.themes.Soft()) as demo:
|
| 139 |
+
gr.HTML("""
|
| 140 |
+
<div style="text-align: center; padding: 20px;">
|
| 141 |
+
<h1>ποΈ Chatterbox-TTS Demo (CPU)</h1>
|
| 142 |
+
<p style="font-size: 18px; color: #666;">
|
| 143 |
+
Generate high-quality speech from text with reference audio styling<br>
|
| 144 |
+
<strong>Running on CPU (Huggingface Space)!</strong>
|
| 145 |
+
</p>
|
| 146 |
+
</div>
|
| 147 |
+
""")
|
| 148 |
+
with gr.Row():
|
| 149 |
+
with gr.Column():
|
| 150 |
+
text = gr.Textbox(
|
| 151 |
+
value="Hello! This is a test of the Chatterbox-TTS voice cloning system running on CPU.",
|
| 152 |
+
label="Text to synthesize (supports long text with automatic chunking)",
|
| 153 |
+
max_lines=10,
|
| 154 |
+
lines=5
|
| 155 |
+
)
|
| 156 |
+
ref_wav = gr.Audio(
|
| 157 |
+
type="filepath",
|
| 158 |
+
label="Reference Audio File (Optional - 6+ seconds recommended)",
|
| 159 |
+
sources=["upload", "microphone"]
|
| 160 |
+
)
|
| 161 |
+
exaggeration = gr.Slider(
|
| 162 |
+
0.25, 2, step=0.05,
|
| 163 |
+
label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
|
| 164 |
+
value=0.5
|
| 165 |
+
)
|
| 166 |
+
cfg_weight = gr.Slider(
|
| 167 |
+
0.2, 1, step=0.05,
|
| 168 |
+
label="CFG/Pace",
|
| 169 |
+
value=0.5
|
| 170 |
+
)
|
| 171 |
+
with gr.Accordion("βοΈ Advanced Options", open=False):
|
| 172 |
+
chunk_size = gr.Slider(
|
| 173 |
+
100, 400, step=25,
|
| 174 |
+
label="Chunk Size (characters per chunk for long text)",
|
| 175 |
+
value=250
|
| 176 |
+
)
|
| 177 |
+
seed_num = gr.Number(
|
| 178 |
+
value=0,
|
| 179 |
+
label="Random seed (0 for random)",
|
| 180 |
+
precision=0
|
| 181 |
+
)
|
| 182 |
+
temp = gr.Slider(
|
| 183 |
+
0.05, 5, step=0.05,
|
| 184 |
+
label="Temperature",
|
| 185 |
+
value=0.8
|
| 186 |
+
)
|
| 187 |
+
run_btn = gr.Button("π΅ Generate Speech", variant="primary", size="lg")
|
| 188 |
+
with gr.Column():
|
| 189 |
+
audio_output = gr.Audio(label="Generated Speech")
|
| 190 |
+
run_btn.click(
|
| 191 |
+
fn=generate_tts_audio,
|
| 192 |
+
inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size],
|
| 193 |
+
outputs=[audio_output],
|
| 194 |
+
show_progress=True
|
| 195 |
+
)
|
| 196 |
+
gr.Examples(
|
| 197 |
+
examples=[
|
| 198 |
+
["Hello! This is a test of voice cloning technology running on CPU."],
|
| 199 |
+
["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet. Now we can test longer text with multiple sentences to see how the chunking works."],
|
| 200 |
+
["Welcome to the future of voice synthesis! With Chatterbox, you can clone any voice in seconds. The technology uses advanced neural networks to capture the unique characteristics of a speaker's voice. This includes their tone, accent, speaking rhythm, and emotional expressiveness. The result is incredibly natural-sounding speech that maintains the original speaker's identity."],
|
| 201 |
+
],
|
| 202 |
+
inputs=[text],
|
| 203 |
+
label="π Example Texts"
|
| 204 |
+
)
|
| 205 |
+
|
| 206 |
+
def main():
|
| 207 |
+
try:
|
| 208 |
+
logger.info("Loading model at startup...")
|
| 209 |
+
get_or_load_model()
|
| 210 |
+
logger.info("β
Startup model loading complete!")
|
| 211 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True, show_error=True)
|
| 212 |
+
except Exception as e:
|
| 213 |
+
logger.error(f"β CRITICAL: Failed to load model on startup: {e}")
|
| 214 |
+
print(f"Application may not function properly. Error: {e}")
|
| 215 |
+
demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True, show_error=True)
|
| 216 |
+
|
| 217 |
+
if __name__ == "__main__":
|
| 218 |
+
main()
|