Voice-Clone-Podcast / app-backup.py
seawolf2357's picture
Rename app.py to app-backup.py
013bd45 verified
raw
history blame
11.7 kB
import random
import numpy as np
import torch
from chatterbox.src.chatterbox.tts import ChatterboxTTS
import gradio as gr
import spaces
import re
# Pick the compute device once at import time; every later load reuses it.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
# NOTE(review): the leading characters are mojibake (likely a 🚀 emoji from a
# bad encoding round-trip) — confirm the file's source encoding.
print(f"๐Ÿš€ Running on device: {DEVICE}")
# --- Global Model Initialization ---
# Module-level cache for the singleton ChatterboxTTS instance; populated
# lazily by get_or_load_model().
MODEL = None
def get_or_load_model():
    """Lazily initialize and return the shared ChatterboxTTS model.

    The first call loads the model onto DEVICE and caches it in the
    module-level ``MODEL``; subsequent calls return the cached instance.

    Raises:
        Exception: re-raises whatever the underlying loader throws, after
            logging it.
    """
    global MODEL
    # Fast path: already loaded.
    if MODEL is not None:
        return MODEL
    print("Model not loaded, initializing...")
    try:
        MODEL = ChatterboxTTS.from_pretrained(DEVICE)
        # Some builds expose .to(); make sure the weights actually live on DEVICE.
        if hasattr(MODEL, 'to') and str(MODEL.device) != DEVICE:
            MODEL.to(DEVICE)
        print(f"Model loaded successfully. Internal device: {getattr(MODEL, 'device', 'N/A')}")
    except Exception as e:
        print(f"Error loading model: {e}")
        raise
    return MODEL
# Attempt to load the model at startup.
# A failure here is logged but deliberately non-fatal: the app still starts,
# and each generation request retries the load via get_or_load_model().
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model on startup. Application may not function. Error: {e}")
def set_seed(seed: int):
    """Seed Python's ``random``, NumPy, and PyTorch RNGs for reproducibility.

    CUDA generators are seeded as well when a CUDA device is available.
    """
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    if torch.cuda.is_available():
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
def split_text_into_chunks(text: str, max_chars: int = 250) -> list[str]:
    """Split *text* into sentence-aligned chunks of at most *max_chars* characters.

    Sentences (delimited by ``.``, ``!`` or ``?`` followed by whitespace) are
    packed greedily into chunks. A sentence longer than *max_chars* is broken
    on word boundaries, and — BUGFIX over the original — a single word longer
    than *max_chars* is hard-split by characters, so the invariant
    ``len(chunk) <= max_chars`` now always holds.

    Args:
        text: Input text; may be empty or whitespace-only.
        max_chars: Maximum length of each returned chunk (>= 1).

    Returns:
        A list of non-empty chunk strings; ``[]`` for blank input.
    """
    text = text.strip()
    if not text:
        return []

    sentences = re.split(r'(?<=[.!?])\s+', text)
    chunks: list[str] = []
    current = ""

    def flush():
        # Move the accumulated chunk (if any) into the output list.
        nonlocal current
        if current:
            chunks.append(current)
            current = ""

    for sentence in sentences:
        # Greedy packing: keep appending sentences while the chunk still fits.
        # The +1 accounts for the joining space, only when a chunk exists.
        if len(current) + len(sentence) + (1 if current else 0) <= max_chars:
            current = f"{current} {sentence}" if current else sentence
            continue
        flush()
        if len(sentence) <= max_chars:
            current = sentence
            continue
        # The sentence alone exceeds the limit: break it on word boundaries.
        for word in sentence.split():
            # Hard-split any word that is itself longer than max_chars.
            while len(word) > max_chars:
                flush()
                chunks.append(word[:max_chars])
                word = word[max_chars:]
            if not word:
                continue
            if current and len(current) + 1 + len(word) > max_chars:
                flush()
            current = f"{current} {word}" if current else word

    flush()
    return chunks
@spaces.GPU
def generate_tts_audio(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
    chunk_size_input: int,
    progress=gr.Progress()
) -> tuple[int, np.ndarray]:
    """Generate TTS audio for arbitrarily long text by chunking.

    Splits the text into chunks of at most ``chunk_size_input`` characters,
    synthesizes each with the shared model, inserts 0.2 s of silence between
    segments, and returns ``(sample_rate, concatenated_waveform)``.
    Runs entirely inside a single GPU context (``@spaces.GPU``).

    Args:
        text_input: Text of any length.
        audio_prompt_path_input: Path/URL of the reference voice clip.
        exaggeration_input: Expressiveness control passed to the model.
        temperature_input: Sampling temperature.
        seed_num_input: RNG seed; 0 means "leave unseeded / random".
        cfgw_input: Classifier-free-guidance weight / pacing.
        chunk_size_input: Maximum characters per chunk.
        progress: Gradio progress reporter (injected by Gradio).

    Raises:
        RuntimeError: if the model is unavailable or every chunk fails.
    """
    current_model = get_or_load_model()
    if current_model is None:
        raise RuntimeError("TTS model is not loaded.")
    if seed_num_input != 0:
        set_seed(int(seed_num_input))

    # Split the text into chunks (log message is mojibake Korean, kept verbatim).
    chunks = split_text_into_chunks(text_input, max_chars=chunk_size_input)
    total_chunks = len(chunks)
    print(f"ํ…์ŠคํŠธ๋ฅผ {total_chunks}๊ฐœ์˜ ์ฒญํฌ๋กœ ๋ถ„ํ• ํ–ˆ์Šต๋‹ˆ๋‹ค.")

    # Synthesize each chunk independently; failures are skipped (best-effort).
    audio_segments = []
    for i, chunk in enumerate(chunks):
        progress((i + 1) / total_chunks, f"์ฒญํฌ {i + 1}/{total_chunks} ์ƒ์„ฑ ์ค‘...")
        print(f"์ฒญํฌ {i + 1}/{total_chunks} ์ƒ์„ฑ ์ค‘: '{chunk[:50]}...'")
        try:
            wav = current_model.generate(
                chunk,
                audio_prompt_path=audio_prompt_path_input,
                exaggeration=exaggeration_input,
                temperature=temperature_input,
                cfg_weight=cfgw_input,
            )
            # BUGFIX: Tensor.numpy() raises on CUDA tensors. detach().cpu()
            # is a no-op for CPU tensors, so this is safe on both devices.
            audio_segments.append(wav.squeeze(0).detach().cpu().numpy())
        except Exception as e:
            print(f"์ฒญํฌ {i + 1} ์ƒ์„ฑ ์ค‘ ์˜ค๋ฅ˜ ๋ฐœ์ƒ: {e}")
            # Deliberate best-effort: keep going with the remaining chunks.
            continue

    if not audio_segments:
        raise RuntimeError("์˜ค๋””์˜ค ์ƒ์„ฑ์— ์‹คํŒจํ–ˆ์Šต๋‹ˆ๋‹ค.")

    # Join the segments with 0.2 s of silence between consecutive chunks
    # for a more natural transition.
    silence = np.zeros(int(0.2 * current_model.sr))
    final_audio = []
    for i, segment in enumerate(audio_segments):
        final_audio.append(segment)
        if i < len(audio_segments) - 1:
            final_audio.append(silence)
    concatenated_audio = np.concatenate(final_audio)
    print(f"์˜ค๋””์˜ค ์ƒ์„ฑ ์™„๋ฃŒ. ์ด ๊ธธ์ด: {len(concatenated_audio) / current_model.sr:.2f}์ดˆ")
    return (current_model.sr, concatenated_audio)
# ๋‹จ์ผ ์ฒญํฌ ์ƒ์„ฑ์„ ์œ„ํ•œ ๊ฐ„๋‹จํ•œ wrapper ํ•จ์ˆ˜ (GPU ๋ฐ์ฝ”๋ ˆ์ดํ„ฐ ํฌํ•จ)
@spaces.GPU
def generate_single_audio(
text_input: str,
audio_prompt_path_input: str,
exaggeration_input: float,
temperature_input: float,
seed_num_input: int,
cfgw_input: float
) -> tuple[int, np.ndarray]:
"""
๋‹จ์ผ ํ…์ŠคํŠธ์— ๋Œ€ํ•œ TTS ์˜ค๋””์˜ค ์ƒ์„ฑ (300์ž ์ดํ•˜)
"""
current_model = get_or_load_model()
if current_model is None:
raise RuntimeError("TTS model is not loaded.")
if seed_num_input != 0:
set_seed(int(seed_num_input))
print(f"Generating audio for text: '{text_input[:50]}...'")
wav = current_model.generate(
text_input[:300], # ์•ˆ์ „์„ ์œ„ํ•ด 300์ž๋กœ ์ œํ•œ
audio_prompt_path=audio_prompt_path_input,
exaggeration=exaggeration_input,
temperature=temperature_input,
cfg_weight=cfgw_input,
)
print("Audio generation complete.")
return (current_model.sr, wav.squeeze(0).numpy())
# ---------------------------------------------------------------------------
# Gradio UI.
# NOTE(review): leading indentation was lost in this copy of the file; the
# nesting below is reconstructed from the Blocks/Row/Column call order —
# confirm against the original app.py. Korean UI strings are mojibake from a
# bad encoding round-trip and are preserved verbatim (they are runtime text).
# ---------------------------------------------------------------------------
with gr.Blocks() as demo:
    # App header.
    gr.Markdown(
        """
# Chatterbox TTS Demo - ๋ฌด์ œํ•œ ๊ธธ์ด ๋ฒ„์ „
๊ธด ํ…์ŠคํŠธ๋„ ์ฒญํฌ๋กœ ๋‚˜๋ˆ„์–ด ์ฒ˜๋ฆฌํ•˜์—ฌ ์ œํ•œ ์—†์ด ์Œ์„ฑ์„ ์ƒ์„ฑํ•ฉ๋‹ˆ๋‹ค.
"""
    )
    with gr.Row():
        with gr.Column():
            # Unlimited-length text input.
            text = gr.Textbox(
                value="Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible.",
                label="ํ…์ŠคํŠธ ์ž…๋ ฅ (๊ธธ์ด ์ œํ•œ ์—†์Œ)",
                lines=10,
                max_lines=30
            )
            # Optional reference voice for cloning; defaults to a hosted sample.
            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )
            with gr.Row():
                exaggeration = gr.Slider(
                    0.25, 2, step=.05,
                    label="Exaggeration (Neutral = 0.5)",
                    value=.5
                )
                cfg_weight = gr.Slider(
                    0.2, 1, step=.05,
                    label="CFG/Pace",
                    value=0.5
                )
            with gr.Row():
                # Maximum characters per chunk in chunked mode.
                chunk_size = gr.Slider(
                    100, 300, step=50,
                    label="์ฒญํฌ ํฌ๊ธฐ (๋ฌธ์ž ์ˆ˜)",
                    value=250,
                    info="ํ…์ŠคํŠธ๋ฅผ ๋‚˜๋ˆŒ ์ฒญํฌ์˜ ์ตœ๋Œ€ ํฌ๊ธฐ์ž…๋‹ˆ๋‹ค. ์ž‘์„์ˆ˜๋ก ๋” ์ž์—ฐ์Šค๋Ÿฝ์ง€๋งŒ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ๊ธธ์–ด์ง‘๋‹ˆ๋‹ค."
                )
            # Generation mode: single-shot (<=300 chars) vs chunked (unlimited).
            mode = gr.Radio(
                choices=["๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)", "์ฒญํฌ ๋ถ„ํ•  (๋ฌด์ œํ•œ)"],
                value="์ฒญํฌ ๋ถ„ํ•  (๋ฌด์ œํ•œ)",
                label="์ƒ์„ฑ ๋ชจ๋“œ"
            )
            # Advanced options: seed and sampling temperature.
            with gr.Accordion("๊ณ ๊ธ‰ ์˜ต์…˜", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)
            run_btn = gr.Button("์Œ์„ฑ ์ƒ์„ฑ", variant="primary")
        with gr.Column():
            audio_output = gr.Audio(label="์ƒ์„ฑ๋œ ์Œ์„ฑ")
            # Read-only character/chunk counter, refreshed by the .change hooks below.
            char_count = gr.Textbox(
                label="ํ…์ŠคํŠธ ์ •๋ณด",
                value="0 ๋ฌธ์ž, ์•ฝ 0๊ฐœ ์ฒญํฌ",
                interactive=False
            )

    # Update the character count and estimated chunk count whenever the
    # text, chunk size, or mode changes.
    def update_char_count(text, chunk_size, mode):
        char_len = len(text)
        if mode == "๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)":
            if char_len > 300:
                return f"{char_len} ๋ฌธ์ž (โš ๏ธ 300์ž ์ดˆ๊ณผ - ์ž˜๋ฆด ์ˆ˜ ์žˆ์Œ)"
            else:
                return f"{char_len} ๋ฌธ์ž"
        else:
            chunks = split_text_into_chunks(text, max_chars=chunk_size)
            chunk_count = len(chunks)
            return f"{char_len} ๋ฌธ์ž, ์•ฝ {chunk_count}๊ฐœ ์ฒญํฌ๋กœ ๋ถ„ํ• ๋จ"

    text.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )
    chunk_size.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )
    mode.change(
        fn=update_char_count,
        inputs=[text, chunk_size, mode],
        outputs=[char_count]
    )

    # Dispatch to the single-shot or chunked generator based on the mode radio.
    def process_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size, mode):
        if mode == "๋‹จ์ผ ์ƒ์„ฑ (300์ž ์ดํ•˜)":
            return generate_single_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight)
        else:
            return generate_tts_audio(text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size)

    run_btn.click(
        fn=process_audio,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            chunk_size,
            mode
        ],
        outputs=[audio_output],
    )

    # Usage tips (mojibake Korean, preserved verbatim).
    gr.Markdown(
        """
### ์‚ฌ์šฉ ํŒ:
- **๋‹จ์ผ ์ƒ์„ฑ ๋ชจ๋“œ**: 300์ž ์ดํ•˜์˜ ์งง์€ ํ…์ŠคํŠธ์— ์ ํ•ฉํ•˜๋ฉฐ ๋น ๋ฅด๊ฒŒ ์ƒ์„ฑ๋ฉ๋‹ˆ๋‹ค
- **์ฒญํฌ ๋ถ„ํ•  ๋ชจ๋“œ**: ๊ธด ํ…์ŠคํŠธ๋ฅผ ์ž๋™์œผ๋กœ ์—ฌ๋Ÿฌ ๋ถ€๋ถ„์œผ๋กœ ๋‚˜๋ˆ„์–ด ์ฒ˜๋ฆฌํ•ฉ๋‹ˆ๋‹ค
- ์ฒญํฌ ํฌ๊ธฐ๋ฅผ ์กฐ์ ˆํ•˜์—ฌ ํ’ˆ์งˆ๊ณผ ์†๋„์˜ ๊ท ํ˜•์„ ๋งž์ถœ ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
- ๊ฐ ์ฒญํฌ ์‚ฌ์ด์—๋Š” ์ž์—ฐ์Šค๋Ÿฌ์šด ์ „ํ™˜์„ ์œ„ํ•ด ์งง์€ ๋ฌด์Œ์ด ์ถ”๊ฐ€๋ฉ๋‹ˆ๋‹ค
- ๋งค์šฐ ๊ธด ํ…์ŠคํŠธ์˜ ๊ฒฝ์šฐ ์ฒ˜๋ฆฌ ์‹œ๊ฐ„์ด ์˜ค๋ž˜ ๊ฑธ๋ฆด ์ˆ˜ ์žˆ์Šต๋‹ˆ๋‹ค
"""
    )

demo.launch()