seawolf2357 committed on
Commit
75fc1e5
·
verified ·
1 Parent(s): 013bd45

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +279 -0
app.py ADDED
@@ -0,0 +1,279 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import random
2
+ import numpy as np
3
+ import torch
4
+ from chatterbox.src.chatterbox.tts import ChatterboxTTS
5
+ import gradio as gr
6
+ import spaces
7
+ import re
8
+
9
# Pick the compute device once at import time; the rest of the module reads
# this constant instead of re-querying torch.
DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
print(f"🚀 Running on device: {DEVICE}")
11
+
12
def set_seed(seed: int) -> None:
    """Seed torch, Python's ``random`` and NumPy for reproducible generation.

    Args:
        seed: Value passed unchanged to every seeding call.
    """
    torch.manual_seed(seed)
    # The CUDA generators are only seeded when the module selected the GPU.
    if DEVICE == "cuda":
        torch.cuda.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
    random.seed(seed)
    np.random.seed(seed)
20
+
21
def _split_long_sentence(sentence: str, max_chars: int) -> list[str]:
    """Pack the words of an over-long sentence into pieces of at most ``max_chars``.

    Words are re-joined with single spaces. A single word longer than
    ``max_chars`` (a URL, or text written without spaces) is cut into
    fixed-size fragments so that no returned piece exceeds the limit.
    """
    pieces: list[str] = []
    current = ""
    for word in sentence.split():
        # A word that already fits yields exactly one fragment: itself.
        for start in range(0, len(word), max_chars):
            fragment = word[start:start + max_chars]
            if not current:
                current = fragment
            elif len(current) + 1 + len(fragment) <= max_chars:
                current = f"{current} {fragment}"
            else:
                pieces.append(current)
                current = fragment
    if current:
        pieces.append(current)
    return pieces


def split_text_into_chunks(text: str, max_chars: int = 250) -> list[str]:
    """Split ``text`` into chunks of at most ``max_chars`` characters.

    The text is first split into sentences at whitespace that follows ``.``,
    ``!`` or ``?``. Consecutive sentences are packed greedily into one chunk
    (joined by a single space) while they fit. A sentence that is itself
    longer than ``max_chars`` is broken at word boundaries, and a word that is
    still too long is cut into ``max_chars``-sized fragments.

    Args:
        text: Input text; leading and trailing whitespace is ignored.
        max_chars: Upper bound on the length of each chunk. Converted with
            ``int()`` so slider values delivered as floats are accepted.

    Returns:
        The chunks in reading order. Empty or whitespace-only input yields
        an empty list.

    Raises:
        ValueError: If ``max_chars`` is smaller than 1.
    """
    max_chars = int(max_chars)
    if max_chars < 1:
        raise ValueError(f"max_chars must be at least 1, got {max_chars}")

    # Basic sentence segmentation: break after terminal punctuation.
    sentences = re.split(r'(?<=[.!?])\s+', text.strip())

    chunks: list[str] = []
    current = ""
    for sentence in sentences:
        if not sentence:
            continue

        # Only count the joining space when there is something to join to.
        separator = 1 if current else 0
        if len(current) + separator + len(sentence) <= max_chars:
            current = f"{current} {sentence}" if current else sentence
            continue

        # The sentence does not fit: close the chunk being built.
        if current:
            chunks.append(current)
            current = ""

        if len(sentence) > max_chars:
            # Forced split. The last piece stays open so that following
            # sentences can still be appended to it.
            pieces = _split_long_sentence(sentence, max_chars)
            if pieces:
                chunks.extend(pieces[:-1])
                current = pieces[-1]
        else:
            current = sentence

    # Flush the final chunk.
    if current:
        chunks.append(current)

    return chunks
67
+
68
@spaces.GPU(duration=120)  # request a generous GPU time budget for multi-chunk runs
def generate_tts_audio_gpu(
    text_input: str,
    audio_prompt_path_input: str,
    exaggeration_input: float,
    temperature_input: float,
    seed_num_input: int,
    cfgw_input: float,
    chunk_size_input: int
) -> tuple[int, np.ndarray]:
    """Generate TTS audio for ``text_input`` and return ``(sample_rate, waveform)``.

    Texts of at most 300 characters are sent to the model in a single call.
    Longer texts are cut with ``split_text_into_chunks`` (limit
    ``chunk_size_input``), each chunk is generated separately, and the results
    are joined with 0.2 s of silence between consecutive chunks.

    Chunk generation is best effort: a chunk whose generation raises is logged
    and skipped, so the output can be missing parts of the text. In the
    single-call path exceptions propagate to the caller.

    Args:
        text_input: Text to synthesize.
        audio_prompt_path_input: Reference audio path forwarded to the model
            as ``audio_prompt_path``.
        exaggeration_input: Forwarded as ``exaggeration``.
        temperature_input: Forwarded as ``temperature``.
        seed_num_input: Random seed; ``0`` (or an empty field) leaves the
            generators unseeded.
        cfgw_input: Forwarded as ``cfg_weight``.
        chunk_size_input: Maximum characters per chunk for long texts.

    Returns:
        ``(model.sr, waveform)`` with the waveform as a NumPy array.

    Raises:
        RuntimeError: If the text was chunked and every chunk failed.
    """
    # The model is loaded inside the GPU-decorated function, not at import time.
    model = ChatterboxTTS.from_pretrained(DEVICE)

    # A cleared number field may arrive as None; treat it like 0 ("random").
    if seed_num_input:
        set_seed(int(seed_num_input))

    def synthesize(piece: str) -> np.ndarray:
        """Run the model on one piece of text and return its waveform."""
        wav = model.generate(
            piece,
            audio_prompt_path=audio_prompt_path_input,
            exaggeration=exaggeration_input,
            temperature=temperature_input,
            cfg_weight=cfgw_input,
        )
        # presumably `wav` is a CPU tensor with a leading batch axis — TODO confirm
        return wav.squeeze(0).numpy()

    # Short text: generate in one pass.
    if len(text_input) <= 300:
        print(f"단일 텍스트 생성: '{text_input[:50]}...'")
        return (model.sr, synthesize(text_input))

    # Long text: split into chunks.
    chunks = split_text_into_chunks(text_input, max_chars=int(chunk_size_input))
    total_chunks = len(chunks)
    print(f"텍스트를 {total_chunks}개의 청크로 분할했습니다.")

    audio_segments = []
    for index, chunk in enumerate(chunks, start=1):
        print(f"청크 {index}/{total_chunks} 생성 중: '{chunk[:50]}...'")
        try:
            audio_segments.append(synthesize(chunk))
        except Exception as e:
            # Deliberate best effort: keep whatever else can be generated.
            print(f"청크 {index} 생성 중 오류 발생: {e}")

    if not audio_segments:
        raise RuntimeError("오디오 생성에 실패했습니다.")

    # Join the segments with 0.2 s of silence in between. The silence uses the
    # segments' dtype: np.zeros defaults to float64, which would make
    # np.concatenate promote the whole result and differ from the short path.
    silence = np.zeros(int(0.2 * model.sr), dtype=audio_segments[0].dtype)

    final_audio = []
    for segment in audio_segments:
        if final_audio:
            final_audio.append(silence)
        final_audio.append(segment)

    concatenated_audio = np.concatenate(final_audio)

    print(f"오디오 생성 완료. 총 길이: {len(concatenated_audio) / model.sr:.2f}초")
    return (model.sr, concatenated_audio)
140
+
141
# Gradio interface
with gr.Blocks() as demo:
    default_text = "Now let's make my mum's favourite. So three mars bars into the pan. Then we add the tuna and just stir for a bit, just let the chocolate and fish infuse. A sprinkle of olive oil and some tomato ketchup. Now smell that. Oh boy this is going to be incredible."
    default_chunk_size = 250

    def update_char_count(text, chunk_size):
        """Describe the current text: its length and, for long text, the chunk plan."""
        text = text or ""
        char_len = len(text)
        if char_len <= 300:
            return f"{char_len} 문자 (단일 생성)"
        chunks = split_text_into_chunks(text, max_chars=chunk_size)
        chunk_count = len(chunks)
        estimated_time = chunk_count * 3  # rough guess: about 3 seconds per chunk
        return f"{char_len} 문자, {chunk_count}개 청크 (예상 시간: 약 {estimated_time}초)"

    def generate_with_status(text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size):
        """Run generation while streaming ``(status, audio)`` updates to the UI."""
        try:
            yield gr.update(value="처리 중... GPU를 할당받는 중입니다."), None

            # Call the GPU function.
            sr, audio = generate_tts_audio_gpu(
                text, ref_wav, exaggeration, temp, seed_num, cfg_weight, chunk_size
            )

            yield gr.update(value="✅ 생성 완료!"), (sr, audio)

        except Exception as e:
            # UI boundary: report the failure in the status box instead of crashing.
            yield gr.update(value=f"❌ 오류 발생: {str(e)}"), None

    gr.Markdown(
        """
        # Chatterbox TTS Demo - 무제한 길이 버전
        긴 텍스트도 청크로 나누어 처리하여 제한 없이 음성을 생성합니다.

        ⚠️ **주의**: 긴 텍스트 처리 시 시간이 오래 걸릴 수 있습니다.
        """
    )

    with gr.Row():
        with gr.Column():
            text = gr.Textbox(
                value=default_text,
                label="텍스트 입력 (길이 제한 없음)",
                lines=10,
                max_lines=30
            )

            ref_wav = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Reference Audio File (Optional)",
                value="https://storage.googleapis.com/chatterbox-demo-samples/prompts/female_shadowheart4.flac"
            )

            with gr.Row():
                exaggeration = gr.Slider(
                    0.25, 2, step=.05,
                    label="Exaggeration (Neutral = 0.5)",
                    value=.5
                )
                cfg_weight = gr.Slider(
                    0.2, 1, step=.05,
                    label="CFG/Pace",
                    value=0.5
                )

            chunk_size = gr.Slider(
                100, 300, step=50,
                label="청크 크기 (문자 수)",
                value=default_chunk_size,
                info="텍스트를 나눌 청크의 최대 크기입니다."
            )

            with gr.Accordion("고급 옵션", open=False):
                seed_num = gr.Number(value=0, label="Random seed (0 for random)")
                temp = gr.Slider(0.05, 5, step=.05, label="Temperature", value=.8)

            run_btn = gr.Button("🎤 음성 생성", variant="primary")

        with gr.Column():
            audio_output = gr.Audio(label="생성된 음성")

            # Text length display. It starts from the same counter the change
            # events use, so it matches the pre-filled text box on first render.
            char_count = gr.Textbox(
                label="텍스트 정보",
                value=update_char_count(default_text, default_chunk_size),
                interactive=False
            )

            status = gr.Textbox(
                label="상태",
                value="대기 중...",
                interactive=False
            )

    # Refresh the character/chunk info whenever the text or chunk size changes.
    text.change(
        fn=update_char_count,
        inputs=[text, chunk_size],
        outputs=[char_count]
    )

    chunk_size.change(
        fn=update_char_count,
        inputs=[text, chunk_size],
        outputs=[char_count]
    )

    run_btn.click(
        fn=generate_with_status,
        inputs=[
            text,
            ref_wav,
            exaggeration,
            temp,
            seed_num,
            cfg_weight,
            chunk_size
        ],
        outputs=[status, audio_output],
    )

    gr.Markdown(
        """
        ### 💡 사용 팁:
        - **300자 이하**: 빠른 단일 생성
        - **300자 초과**: 자동으로 청크로 분할하여 처리
        - 청크 크기가 작을수록 자연스럽지만 처리 시간이 증가합니다
        - GPU 할당을 기다리는 시간이 있을 수 있습니다

        ### ⏱️ 예상 처리 시간:
        - 300자 이하: 약 5-10초
        - 1000자: 약 15-30초
        - 5000자: 약 1-2분
        """
    )
275
+
276
# No model is loaded at startup; loading happens only inside the GPU function.
print("앱이 시작되었습니다. 모델은 첫 생성 시 로드됩니다.")

demo.queue().launch()