Ryanus commited on
Commit
cd40891
·
verified ·
1 Parent(s): cad7a1b

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +42 -211
app.py CHANGED
@@ -1,218 +1,49 @@
1
- import random
2
- import numpy as np
3
  import torch
4
  import gradio as gr
5
- import logging
6
  from pathlib import Path
7
- import sys
8
- import re
9
- from typing import List
10
-
11
- logging.basicConfig(level=logging.INFO)
12
- logger = logging.getLogger(__name__)
13
-
14
- # 強制 torch.load 使用 CPU
15
- original_torch_load = torch.load
16
- def patched_torch_load(f, map_location=None, **kwargs):
17
- if map_location is None:
18
- map_location = 'cpu'
19
- logger.info(f"🔧 Loading with map_location={map_location}")
20
- return original_torch_load(f, map_location=map_location, **kwargs)
21
- torch.load = patched_torch_load
22
- if 'torch' in sys.modules:
23
- sys.modules['torch'].load = patched_torch_load
24
- logger.info("✅ Applied torch.load device mapping patch")
25
-
26
- DEVICE = "cpu"
27
- logger.info("🚀 Running on CPU")
28
-
29
- MODEL = None
30
- def get_or_load_model():
31
- global MODEL, DEVICE
32
- if MODEL is None:
33
- print("Model not loaded, initializing...")
34
- try:
35
- try:
36
- from chatterbox.src.chatterbox.tts import ChatterboxTTS
37
- logger.info("✅ Using official chatterbox.src import path")
38
- except ImportError:
39
- from chatterbox import ChatterboxTTS
40
- logger.info("✅ Using chatterbox direct import path")
41
- MODEL = ChatterboxTTS.from_pretrained("cpu")
42
- MODEL.device = "cpu"
43
- logger.info(f"✅ Model loaded successfully on {DEVICE}")
44
- except Exception as e:
45
- logger.error(f"❌ Error loading model: {e}")
46
- raise
47
- return MODEL
48
-
49
- def set_seed(seed: int):
50
- torch.manual_seed(seed)
51
- random.seed(seed)
52
- np.random.seed(seed)
53
-
54
- def split_text_into_chunks(text: str, max_chars: int = 250) -> List[str]:
55
- if len(text) <= max_chars:
56
- return [text]
57
- sentences = re.split(r'(?<=[.!?])\s+', text)
58
- chunks = []
59
- current_chunk = ""
60
- for sentence in sentences:
61
- if len(sentence) > max_chars:
62
- if current_chunk:
63
- chunks.append(current_chunk.strip())
64
- current_chunk = ""
65
- parts = re.split(r'(?<=,)\s+', sentence)
66
- for part in parts:
67
- if len(part) > max_chars:
68
- words = part.split()
69
- word_chunk = ""
70
- for word in words:
71
- if len(word_chunk + " " + word) <= max_chars:
72
- word_chunk += " " + word if word_chunk else word
73
- else:
74
- if word_chunk:
75
- chunks.append(word_chunk.strip())
76
- word_chunk = word
77
- if word_chunk:
78
- chunks.append(word_chunk.strip())
79
- else:
80
- if len(current_chunk + " " + part) <= max_chars:
81
- current_chunk += " " + part if current_chunk else part
82
- else:
83
- if current_chunk:
84
- chunks.append(current_chunk.strip())
85
- current_chunk = part
86
- else:
87
- if len(current_chunk + " " + sentence) <= max_chars:
88
- current_chunk += " " + sentence if current_chunk else sentence
89
- else:
90
- if current_chunk:
91
- chunks.append(current_chunk.strip())
92
- current_chunk = sentence
93
- if current_chunk:
94
- chunks.append(current_chunk.strip())
95
- return [chunk for chunk in chunks if chunk.strip()]
96
-
97
- def generate_tts_audio(
98
- text_input: str,
99
- audio_prompt_path_input: str,
100
- exaggeration_input: float,
101
- temperature_input: float,
102
- seed_num_input: int,
103
- cfgw_input: float,
104
- chunk_size: int = 250
105
- ) -> tuple[int, np.ndarray]:
106
- try:
107
- current_model = get_or_load_model()
108
- if current_model is None:
109
- raise RuntimeError("TTS model is not loaded.")
110
- if seed_num_input != 0:
111
- set_seed(int(seed_num_input))
112
- text_chunks = split_text_into_chunks(text_input, chunk_size)
113
- logger.info(f"Processing {len(text_chunks)} text chunk(s)")
114
- generated_wavs = []
115
- for i, chunk in enumerate(text_chunks):
116
- logger.info(f"Generating chunk {i+1}/{len(text_chunks)}: '{chunk[:50]}...'")
117
- wav = current_model.generate(
118
- chunk,
119
- audio_prompt_path=audio_prompt_path_input,
120
- exaggeration=exaggeration_input,
121
- temperature=temperature_input,
122
- cfg_weight=cfgw_input,
123
- )
124
- generated_wavs.append(wav)
125
- if len(generated_wavs) > 1:
126
- silence_samples = int(0.3 * current_model.sr)
127
- silence = torch.zeros(1, silence_samples, dtype=generated_wavs[0].dtype)
128
- final_wav = generated_wavs[0]
129
- for wav_chunk in generated_wavs[1:]:
130
- final_wav = torch.cat([final_wav, silence, wav_chunk], dim=1)
131
- else:
132
- final_wav = generated_wavs[0]
133
- return (current_model.sr, final_wav.squeeze(0).numpy())
134
- except Exception as e:
135
- logger.error(f"❌ Generation failed: {e}")
136
- raise gr.Error(f"Generation failed: {str(e)}")
137
-
138
- with gr.Blocks(title="🎙️ Chatterbox-TTS (CPU)", theme=gr.themes.Soft()) as demo:
139
- gr.HTML("""
140
- <div style="text-align: center; padding: 20px;">
141
- <h1>🎙️ Chatterbox-TTS Demo (CPU)</h1>
142
- <p style="font-size: 18px; color: #666;">
143
- Generate high-quality speech from text with reference audio styling<br>
144
- <strong>Running on CPU (Huggingface Space)!</strong>
145
- </p>
146
- </div>
147
- """)
148
- with gr.Row():
149
- with gr.Column():
150
- text = gr.Textbox(
151
- value="Hello! This is a test of the Chatterbox-TTS voice cloning system running on CPU.",
152
- label="Text to synthesize (supports long text with automatic chunking)",
153
- max_lines=10,
154
- lines=5
155
- )
156
- ref_wav = gr.Audio(
157
- type="filepath",
158
- label="Reference Audio File (Optional - 6+ seconds recommended)",
159
- sources=["upload", "microphone"]
160
- )
161
- exaggeration = gr.Slider(
162
- 0.25, 2, step=0.05,
163
- label="Exaggeration (Neutral = 0.5, extreme values can be unstable)",
164
- value=0.5
165
- )
166
- cfg_weight = gr.Slider(
167
- 0.2, 1, step=0.05,
168
- label="CFG/Pace",
169
- value=0.5
170
- )
171
- with gr.Accordion("⚙️ Advanced Options", open=False):
172
- chunk_size = gr.Slider(
173
- 100, 400, step=25,
174
- label="Chunk Size (characters per chunk for long text)",
175
- value=250
176
- )
177
- seed_num = gr.Number(
178
- value=0,
179
- label="Random seed (0 for random)",
180
- precision=0
181
- )
182
- temp = gr.Slider(
183
- 0.05, 5, step=0.05,
184
- label="Temperature",
185
- value=0.8
186
- )
187
- run_btn = gr.Button("🎵 Generate Speech", variant="primary", size="lg")
188
- with gr.Column():
189
- audio_output = gr.Audio(label="Generated Speech")
190
- run_btn.click(
191
- fn=generate_tts_audio,
192
- inputs=[text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size],
193
- outputs=[audio_output],
194
- show_progress=True
195
  )
196
- gr.Examples(
197
- examples=[
198
- ["Hello! This is a test of voice cloning technology running on CPU."],
199
- ["The quick brown fox jumps over the lazy dog. This sentence contains every letter of the alphabet. Now we can test longer text with multiple sentences to see how the chunking works."],
200
- ["Welcome to the future of voice synthesis! With Chatterbox, you can clone any voice in seconds. The technology uses advanced neural networks to capture the unique characteristics of a speaker's voice. This includes their tone, accent, speaking rhythm, and emotional expressiveness. The result is incredibly natural-sounding speech that maintains the original speaker's identity."],
201
- ],
202
- inputs=[text],
203
- label="📝 Example Texts"
 
 
 
 
 
 
 
 
 
 
 
 
204
  )
205
 
206
- def main():
207
- try:
208
- logger.info("Loading model at startup...")
209
- get_or_load_model()
210
- logger.info("✅ Startup model loading complete!")
211
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True, show_error=True)
212
- except Exception as e:
213
- logger.error(f"❌ CRITICAL: Failed to load model on startup: {e}")
214
- print(f"Application may not function properly. Error: {e}")
215
- demo.launch(server_name="0.0.0.0", server_port=7860, share=True, debug=True, show_error=True)
216
-
217
  if __name__ == "__main__":
218
- main()
 
1
+ import os
2
+ import time
3
  import torch
4
  import gradio as gr
 
5
  from pathlib import Path
6
+ import torchaudio
7
+ from chatterbox.tts import ChatterboxTTS
8
+
9
+ # 初始化儲存資料夾
10
+ OUTPUT_DIR = Path("outputs")
11
+ OUTPUT_DIR.mkdir(exist_ok=True)
12
+
13
+ # 載入模型
14
+ model = ChatterboxTTS.from_pretrained(device="cpu")
15
+
16
+ def tts_and_save(text, ref_wav, exaggeration, temperature, seed, cfg_weight):
17
+ if seed != 0:
18
+ torch.manual_seed(int(seed))
19
+ wav = model.generate(
20
+ text,
21
+ audio_prompt_path=ref_wav,
22
+ exaggeration=exaggeration,
23
+ temperature=temperature,
24
+ cfg_weight=cfg_weight,
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
25
  )
26
+ timestamp = time.strftime("%Y%m%d_%H%M%S")
27
+ filename = OUTPUT_DIR / f"tts_{timestamp}.wav"
28
+ torchaudio.save(str(filename), wav.cpu(), model.sr)
29
+ return (model.sr, wav.squeeze(0).numpy()), str(filename)
30
+
31
+ with gr.Blocks() as demo:
32
+ text = gr.Textbox(label="輸入文字")
33
+ ref_wav = gr.Audio(label="參考語音(可選)", sources=["upload", "microphone"], type="filepath")
34
+ exaggeration = gr.Slider(0.25, 2, value=0.5, step=0.05, label="Exaggeration")
35
+ cfg_weight = gr.Slider(0.2, 1, value=0.5, step=0.05, label="CFG/Pace")
36
+ temperature = gr.Slider(0.05, 5, value=0.8, step=0.05, label="Temperature")
37
+ seed = gr.Number(value=0, label="隨機種子 (0=隨機)", precision=0)
38
+ btn = gr.Button("生成並自動儲存")
39
+ output_audio = gr.Audio(label="語音預覽")
40
+ saved_path = gr.Textbox(label="儲存路徑", interactive=False)
41
+
42
+ btn.click(
43
+ tts_and_save,
44
+ inputs=[text, ref_wav, exaggeration, temperature, seed, cfg_weight],
45
+ outputs=[output_audio, saved_path]
46
  )
47
 
 
 
 
 
 
 
 
 
 
 
 
48
  if __name__ == "__main__":
49
+ demo.launch()