tahirturk committed on
Commit
a9bcfd7
·
1 Parent(s): 30c1824
Files changed (1) hide show
  1. app.py +20 -94
app.py CHANGED
@@ -3,28 +3,15 @@ import re
3
  import numpy as np
4
  import torch
5
  import torchaudio
6
- import warnings
7
- import importlib
8
-
9
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
10
  import gradio as gr
11
  import spaces
12
 
13
- # ===========================================
14
- # ✅ Environment & Warnings Cleanup
15
- # ===========================================
16
- warnings.filterwarnings("ignore", category=UserWarning)
17
- warnings.filterwarnings("ignore", category=FutureWarning)
18
- torch.set_printoptions(precision=4, sci_mode=False)
19
-
20
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
21
  print(f"🚀 Running on device: {DEVICE}")
22
 
23
  MODEL = None
24
 
25
- # ===========================================
26
- # ✅ Default Language Configurations
27
- # ===========================================
28
  LANGUAGE_CONFIG = {
29
  "ar": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
30
  "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."},
@@ -40,88 +27,33 @@ LANGUAGE_CONFIG = {
40
  "text": "上个月,我们达到了一个新的里程碑。 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"},
41
  }
42
 
43
-
44
  def default_audio_for_ui(lang: str) -> str | None:
45
  return LANGUAGE_CONFIG.get(lang, {}).get("audio")
46
 
47
-
48
  def default_text_for_ui(lang: str) -> str:
49
  return LANGUAGE_CONFIG.get(lang, {}).get("text", "")
50
 
51
-
52
  def get_supported_languages_display() -> str:
53
  items = [f"**{name}** (`{code}`)" for code, name in sorted(SUPPORTED_LANGUAGES.items())]
54
- mid = len(items) // 2
55
  return f"### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)\n" \
56
  f"{' • '.join(items[:mid])}\n\n{' • '.join(items[mid:])}"
57
 
58
-
59
- # ===========================================
60
- # ✅ Smart & Safe Model Loader
61
- # ===========================================
62
  def get_or_load_model():
63
  global MODEL
64
  if MODEL is None:
65
- print("🔄 Loading TTS model...")
66
-
67
- # Try to detect transformers version
68
- try:
69
- import transformers
70
- tf_version = transformers.__version__
71
- except Exception:
72
- tf_version = "unknown"
73
-
74
- # Detect whether attn_implementation is supported
75
- supports_attn = False
76
- try:
77
- from inspect import signature
78
- sig = signature(ChatterboxMultilingualTTS.from_pretrained)
79
- supports_attn = "attn_implementation" in sig.parameters
80
- except Exception:
81
- pass
82
-
83
- try:
84
- if supports_attn:
85
- print(f"⚙️ Using Transformers v{tf_version} with attn_implementation='eager'")
86
- MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE, attn_implementation="eager")
87
- else:
88
- print(f"⚙️ Using Transformers v{tf_version} (attn_implementation not supported)")
89
- MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
90
- except TypeError:
91
- print("⚠️ Fallback: attn_implementation not accepted — loading default config")
92
- MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
93
- except RuntimeError as e:
94
- # Handle out-of-memory and auto CPU fallback
95
- if "CUDA out of memory" in str(e) or "CUDA error" in str(e):
96
- print("💡 GPU memory insufficient. Falling back to CPU...")
97
- global DEVICE
98
- DEVICE = "cpu"
99
- MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
100
- else:
101
- raise e
102
- except Exception as e:
103
- print(f"❌ Model loading failed: {e}")
104
- raise
105
-
106
- # Move to appropriate device
107
  if hasattr(MODEL, "to"):
108
  MODEL.to(DEVICE)
109
-
110
- # Optional flatten for RNN memory warning
111
- if hasattr(MODEL, "rnn") and hasattr(MODEL.rnn, "flatten_parameters"):
112
- try:
113
- MODEL.rnn.flatten_parameters()
114
- except Exception:
115
- pass
116
-
117
  print(f"✅ Model loaded successfully on {DEVICE}")
118
- print(f"💡 Attention mode: {'eager' if supports_attn else 'default'}")
119
  return MODEL
120
 
 
 
 
 
121
 
122
- # ===========================================
123
- # ✅ Helper Utilities
124
- # ===========================================
125
  def set_seed(seed: int):
126
  torch.manual_seed(seed)
127
  if DEVICE == "cuda":
@@ -130,18 +62,16 @@ def set_seed(seed: int):
130
  random.seed(seed)
131
  np.random.seed(seed)
132
 
133
-
134
  def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
135
  if provided_path and str(provided_path).strip():
136
  return provided_path
137
  return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
138
 
139
-
140
  # ✅ Text chunking helper
141
  def split_text_into_chunks(text: str, max_chars: int = 500) -> list[str]:
142
  """
143
- Split text into manageable chunks (300–500 characters),
144
- breaking on sentence boundaries.
145
  """
146
  text = re.sub(r"\s+", " ", text.strip())
147
  if len(text) <= max_chars:
@@ -161,10 +91,6 @@ def split_text_into_chunks(text: str, max_chars: int = 500) -> list[str]:
161
 
162
  return [c for c in chunks if c]
163
 
164
-
165
- # ===========================================
166
- # ✅ TTS Generation
167
- # ===========================================
168
  @spaces.GPU
169
  def generate_tts_audio(
170
  text_input: str,
@@ -183,7 +109,7 @@ def generate_tts_audio(
183
  if seed_num_input != 0:
184
  set_seed(int(seed_num_input))
185
 
186
- print(f"\n🗣 Generating speech text length: {len(text_input)}")
187
 
188
  chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
189
  generate_kwargs = {
@@ -195,30 +121,30 @@ def generate_tts_audio(
195
  generate_kwargs["audio_prompt_path"] = chosen_prompt
196
  print(f"🎧 Using reference: {chosen_prompt}")
197
  else:
198
- print("🎙 Using default neutral voice (no reference).")
199
 
 
200
  chunks = split_text_into_chunks(text_input)
201
- print(f"🪄 Text split into {len(chunks)} chunks")
202
 
203
  all_audio = []
 
204
  for i, chunk in enumerate(chunks):
205
- print(f"🔹 Generating chunk {i + 1}/{len(chunks)} ({len(chunk)} chars)...")
206
  wav = current_model.generate(chunk, language_id=language_id, **generate_kwargs)
207
  all_audio.append(wav.squeeze(0).cpu())
208
 
 
209
  final_audio = torch.cat(all_audio, dim=-1)
210
- print("✅ Audio generation complete.\n")
211
  return (current_model.sr, final_audio.numpy())
212
 
213
-
214
- # ===========================================
215
- # ✅ Gradio UI
216
- # ===========================================
217
  with gr.Blocks() as demo:
218
  gr.Markdown("""
219
  # 🎙️ Multi Language Realistic Voice Cloner
220
  Generate long-form multilingual speech with reference audio styling and auto-chunking support.
221
- **By Tahir Turk**
222
  """)
223
 
224
  gr.Markdown(get_supported_languages_display())
@@ -270,4 +196,4 @@ with gr.Blocks() as demo:
270
  outputs=[audio_output],
271
  )
272
 
273
- demo.launch(mcp_server=True, share=True)
 
3
  import numpy as np
4
  import torch
5
  import torchaudio
 
 
 
6
  from src.chatterbox.mtl_tts import ChatterboxMultilingualTTS, SUPPORTED_LANGUAGES
7
  import gradio as gr
8
  import spaces
9
 
 
 
 
 
 
 
 
10
  DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
11
  print(f"🚀 Running on device: {DEVICE}")
12
 
13
  MODEL = None
14
 
 
 
 
15
  LANGUAGE_CONFIG = {
16
  "ar": {"audio": "https://storage.googleapis.com/chatterbox-demo-samples/mtl_prompts/ar_f/ar_prompts2.flac",
17
  "text": "في الشهر الماضي، وصلنا إلى معلم جديد بمليارين من المشاهدات على قناتنا على يوتيوب."},
 
27
  "text": "上个月,我们达到了一个新的里程碑。 我们的YouTube频道观看次数达到了二十亿次,这绝对令人难以置信。"},
28
  }
29
 
 
30
def default_audio_for_ui(lang: str) -> str | None:
    """Return the default reference-audio URL for *lang*, or None if unconfigured."""
    entry = LANGUAGE_CONFIG.get(lang)
    if entry is None:
        return None
    return entry.get("audio")
32
 
 
33
def default_text_for_ui(lang: str) -> str:
    """Return the sample sentence configured for *lang*, or "" when absent."""
    entry = LANGUAGE_CONFIG.get(lang)
    if entry is None:
        return ""
    return entry.get("text", "")
35
 
 
36
def get_supported_languages_display() -> str:
    """Build a Markdown summary of all supported languages, split over two rows."""
    entries = [f"**{name}** (`{code}`)" for code, name in sorted(SUPPORTED_LANGUAGES.items())]
    half = len(entries) // 2
    first_row = ' • '.join(entries[:half])
    second_row = ' • '.join(entries[half:])
    header = f"### 🌍 Supported Languages ({len(SUPPORTED_LANGUAGES)} total)\n"
    return header + f"{first_row}\n\n{second_row}"
41
 
 
 
 
 
42
def get_or_load_model():
    """Lazily load the multilingual TTS model into the module-global MODEL.

    The first call initializes the model on DEVICE; later calls return the
    cached instance. Returns the loaded ChatterboxMultilingualTTS model.
    """
    global MODEL
    if MODEL is None:
        print("Model not loaded, initializing...")
        # from_pretrained is given DEVICE ("cuda" or "cpu") directly.
        MODEL = ChatterboxMultilingualTTS.from_pretrained(DEVICE)
        # NOTE(review): nesting reconstructed from a diff view — the hasattr
        # guard suggests MODEL may not always expose .to(); confirm against
        # the original file whether these lines sit inside the load branch.
        if hasattr(MODEL, "to"):
            MODEL.to(DEVICE)
        print(f"✅ Model loaded successfully on {DEVICE}")
    return MODEL
51
 
52
# Eagerly warm up the model at import time so the first user request does not
# pay the load cost. Best-effort: a failure is logged and startup continues,
# letting the Gradio UI come up (generation will retry loading on demand).
try:
    get_or_load_model()
except Exception as e:
    print(f"CRITICAL: Failed to load model. Error: {e}")
56
 
 
 
 
57
  def set_seed(seed: int):
58
  torch.manual_seed(seed)
59
  if DEVICE == "cuda":
 
62
  random.seed(seed)
63
  np.random.seed(seed)
64
 
 
65
  def resolve_audio_prompt(language_id: str, provided_path: str | None) -> str | None:
66
  if provided_path and str(provided_path).strip():
67
  return provided_path
68
  return LANGUAGE_CONFIG.get(language_id, {}).get("audio")
69
 
 
70
  # ✅ Text chunking helper
71
  def split_text_into_chunks(text: str, max_chars: int = 500) -> list[str]:
72
  """
73
+ Split text into manageable chunks around 300 characters each,
74
+ breaking on sentence boundaries (., ?, !, etc.).
75
  """
76
  text = re.sub(r"\s+", " ", text.strip())
77
  if len(text) <= max_chars:
 
91
 
92
  return [c for c in chunks if c]
93
 
 
 
 
 
94
  @spaces.GPU
95
  def generate_tts_audio(
96
  text_input: str,
 
109
  if seed_num_input != 0:
110
  set_seed(int(seed_num_input))
111
 
112
+ print(f"🗣 Generating audio for text length={len(text_input)}")
113
 
114
  chosen_prompt = audio_prompt_path_input or default_audio_for_ui(language_id)
115
  generate_kwargs = {
 
121
  generate_kwargs["audio_prompt_path"] = chosen_prompt
122
  print(f"🎧 Using reference: {chosen_prompt}")
123
  else:
124
+ print("No reference provided, using default voice.")
125
 
126
+ # ✅ Split text into manageable chunks
127
  chunks = split_text_into_chunks(text_input)
128
+ print(f"🪄 Split text into {len(chunks)} chunks")
129
 
130
  all_audio = []
131
+
132
  for i, chunk in enumerate(chunks):
133
+ print(f"🔹 Generating chunk {i+1}/{len(chunks)} ({len(chunk)} chars)")
134
  wav = current_model.generate(chunk, language_id=language_id, **generate_kwargs)
135
  all_audio.append(wav.squeeze(0).cpu())
136
 
137
+ # ✅ Concatenate all audio segments
138
  final_audio = torch.cat(all_audio, dim=-1)
139
+ print("✅ Audio generation complete.")
140
  return (current_model.sr, final_audio.numpy())
141
 
142
+ # === Gradio Interface ===
 
 
 
143
  with gr.Blocks() as demo:
144
  gr.Markdown("""
145
  # 🎙️ Multi Language Realistic Voice Cloner
146
  Generate long-form multilingual speech with reference audio styling and auto-chunking support.
147
+ By Tahir Turk
148
  """)
149
 
150
  gr.Markdown(get_supported_languages_display())
 
196
  outputs=[audio_output],
197
  )
198
 
199
+ demo.launch(mcp_server=True, share=True)