codewithjarair committed on
Commit
4f3b5bd
·
verified ·
1 Parent(s): 1cc1f7e

Update app.py

Browse files
Files changed (1) hide show
  1. app.py +107 -146
app.py CHANGED
@@ -1,160 +1,121 @@
1
  import os
2
- import random
3
- import re
4
- import tempfile
5
- import torch
6
- import torchaudio
7
- import numpy as np
8
  import gradio as gr
9
- from chatterbox.tts import ChatterboxTTS
10
 
11
# Constants
# Upper bound, in characters, for a single text chunk sent to the TTS model.
MAX_CHUNK_CHARS = 250
# Use the GPU when torch reports CUDA support, otherwise fall back to the CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"
14
 
15
class VoiceCloningEngine:
    """
    A dedicated engine to handle Chatterbox TTS operations including
    model management, text chunking, and audio generation.

    The model is loaded lazily by ``load_model`` (called from ``generate``),
    so constructing the engine does not load any weights.
    """

    def __init__(self, device=None):
        """
        Create an engine without loading the model.

        Args:
            device: Device string passed to ``ChatterboxTTS.from_pretrained``.
                ``None`` (the default) selects the module-level
                ``DEFAULT_DEVICE``.
        """
        self.device = DEFAULT_DEVICE if device is None else device
        self.model = None
        self.sr = 24000  # Default Chatterbox SR; replaced by model.sr on load

    def load_model(self):
        """Loads the model into memory if not already present and returns it."""
        if self.model is None:
            print(f"Loading Chatterbox TTS on {self.device}...")
            self.model = ChatterboxTTS.from_pretrained(self.device)
            self.sr = self.model.sr
        return self.model

    def set_seed(self, seed: int):
        """
        Seeds the torch, Python and NumPy random number generators.

        Args:
            seed: Seed to use; ``0`` means "pick a random seed in
                [1, 1000000]".

        Returns:
            The seed that was actually applied.
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text, max_chars=None):
        """
        Splits text into chunks at sentence boundaries for long script handling.

        Sentences are packed greedily into chunks of at most ``max_chars``
        characters. A sentence longer than the limit is split at whitespace,
        and a single token longer than the limit is cut into fixed-size
        slices, so no returned chunk exceeds the limit.

        Args:
            text: The script to split. ``None``, empty, or whitespace-only
                text yields an empty list.
            max_chars: Maximum characters per chunk. ``None`` (the default)
                uses the module-level ``MAX_CHUNK_CHARS``.

        Returns:
            A list of non-empty chunk strings in reading order.

        Raises:
            ValueError: If ``max_chars`` is smaller than 1.
        """
        limit = MAX_CHUNK_CHARS if max_chars is None else int(max_chars)
        if limit < 1:
            raise ValueError("max_chars must be a positive integer.")
        if not text or not text.strip():
            return []

        chunks = []
        current = ""
        # Split by punctuation followed by whitespace.
        for sentence in re.split(r"(?<=[.!?])\s+", text.strip()):
            if not sentence:
                continue

            if len(sentence) > limit:
                # An oversize sentence always starts on a fresh chunk.
                if current:
                    chunks.append(current)
                    current = ""
                pieces = []
                for word in re.split(r"\s+", sentence):
                    # Hard-split tokens that cannot fit in any chunk.
                    pieces.extend(
                        word[start:start + limit]
                        for start in range(0, len(word), limit)
                    )
            else:
                pieces = [sentence]

            for piece in pieces:
                candidate = f"{current} {piece}" if current else piece
                if len(candidate) <= limit:
                    current = candidate
                else:
                    chunks.append(current)
                    current = piece

        if current:
            chunks.append(current)
        return chunks

    def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=None):
        """
        Processes the full script by chunking and concatenating results.

        Args:
            text: Script to synthesize.
            ref_audio: Reference audio, forwarded to the model as
                ``audio_prompt_path``.
            exaggeration: Forwarded to ``model.generate``.
            cfg_weight: Forwarded to ``model.generate``.
            temperature: Forwarded to ``model.generate``.
            seed: Seed for ``set_seed``; ``0`` or ``None`` picks a random one.
            progress: Optional callable invoked as
                ``progress(fraction, desc=...)`` before each chunk.

        Returns:
            Tuple ``(output_path, actual_seed)`` where ``output_path`` is a
            ``.wav`` file that is NOT deleted automatically.

        Raises:
            ValueError: If the text yields no chunks or ``ref_audio`` is None.
        """
        # Validate the cheap preconditions before the expensive model load.
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("No valid text provided.")
        if ref_audio is None:
            raise ValueError("Reference audio is required for cloning.")

        self.load_model()
        actual_seed = self.set_seed(int(seed or 0))

        all_wavs = []
        total = len(chunks)
        for i, chunk in enumerate(chunks):
            if progress:
                progress((i / total), desc=f"Processing chunk {i+1}/{total}")

            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight,
            )

            # Give 1-D output a leading dimension so every segment is 2-D
            # before concatenating along the last axis.
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            all_wavs.append(wav.cpu())

        final_wav = torch.cat(all_wavs, dim=-1)

        # delete=False: the path is handed back to the caller, so the file
        # must outlive this function.
        with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
            output_path = tmp.name
        torchaudio.save(output_path, final_wav, self.sr)

        return output_path, actual_seed
115
 
116
# Initialize the engine (constructing it does not load the model).
engine = VoiceCloningEngine()
118
-
119
def process_tts(text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=gr.Progress()):
    """
    Gradio callback: run the engine and turn the outcome into UI values.

    Returns:
        ``(path, message)`` on success, or ``(None, "Error: ...")`` when the
        engine raises. Errors are reported in the status message instead of
        propagating to Gradio.
    """
    try:
        path, used_seed = engine.generate(text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress)
        return path, f"Success! Seed used: {used_seed}"
    except Exception as e:
        # UI boundary: keep the app responsive, but keep the traceback in
        # the server log so failures can be diagnosed.
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
125
 
126
# UI Construction
# NOTE(review): the original indentation was lost; the nesting below is
# inferred from the component order. Confirm the Quick Guide placement.
with gr.Blocks(theme=gr.themes.Soft(), title="Chatterbox Voice Clone") as demo:
    gr.Markdown("# 🗣️ Voice Cloning TTS Engine")
    gr.Markdown("Optimized for long scripts with intelligent chunking and smooth concatenation.")

    with gr.Row():
        # Input column: script, reference voice and generation settings.
        with gr.Column(scale=1):
            text_input = gr.Textbox(label="Script", lines=8, placeholder="Enter long text here...")
            ref_audio = gr.Audio(label="Reference Voice", type="filepath")

            with gr.Row():
                exag = gr.Slider(0.1, 1.0, value=0.5, label="Exaggeration", info="Warning: >0.8 can be unstable")
                cfg = gr.Slider(0.0, 1.0, value=0.5, label="CFG/Pace")

            with gr.Accordion("Advanced Options", open=False):
                seed_val = gr.Number(label="Seed", value=0, precision=0, info="0 for random")
                temp_val = gr.Slider(0.1, 2.0, value=1.0, label="Temperature")

            btn = gr.Button("Generate", variant="primary")

        # Output column: generated audio, status line and a short guide.
        with gr.Column(scale=1):
            audio_out = gr.Audio(label="Generated Audio", type="filepath")
            status = gr.Textbox(label="Status", interactive=False)

            gr.Markdown("### 📖 Quick Guide")
            gr.Markdown("""
            - **Chunking**: Sentences are automatically split at ~250 chars.
            - **Secrets**: Use HF Secrets for API keys if needed.
            - **Pacing**: Lower CFG for slower, more deliberate speech.
            """)

    # Input order must match process_tts's positional parameters:
    # text, ref_audio, exaggeration, cfg_weight, temperature, seed.
    btn.click(process_tts, [text_input, ref_audio, exag, cfg, temp_val, seed_val], [audio_out, status])
 
 
 
 
 
 
 
 
 
 
 
 
 
 
158
 
159
if __name__ == "__main__":
    # Bind to all interfaces instead of the loopback-only default.
    demo.launch(server_name="0.0.0.0")
 
 
 
1
  import os
 
 
 
 
 
 
2
  import gradio as gr
3
+ from engine import VoiceCloningEngine
4
 
5
# Initialize the Voice Cloning Engine (shared by every UI callback).
engine = VoiceCloningEngine()
 
7
 
8
def process_tts(text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress=gr.Progress()):
    """
    Main TTS processing function connecting the UI with the VoiceCloningEngine.

    Args:
        text: Script to synthesize; ``None``, empty, or whitespace-only text
            is rejected.
        ref_audio: Reference audio for the voice to clone; ``None`` is
            rejected.
        exaggeration, cfg_weight, temperature, seed: Forwarded unchanged to
            ``engine.generate``.
        progress: Progress object forwarded to the engine as
            ``progress_callback``.

    Returns:
        ``(output_path, message)`` on success, or ``(None, "Error: ...")``
        when validation fails or the engine raises.
    """
    # `not text` guards against None, which would otherwise raise
    # AttributeError on .strip() outside the try block below.
    if not text or not text.strip():
        return None, "Error: Please enter a script."
    if ref_audio is None:
        return None, "Error: Please upload a reference audio clip."

    try:
        # Call the engine with the Gradio Progress callback
        output_path, used_seed = engine.generate(
            text=text,
            ref_audio=ref_audio,
            exaggeration=exaggeration,
            cfg_weight=cfg_weight,
            temperature=temperature,
            seed=seed,
            progress_callback=progress,
        )
        return output_path, f"Successfully generated audio with seed {used_seed}."
    except Exception as e:
        # UI boundary: report the failure in the status box and keep the
        # full traceback in the server log.
        import traceback
        traceback.print_exc()
        return None, f"Error: {str(e)}"
33
 
34
# UI Layout and Configuration
def create_ui():
    """
    Build the Gradio Blocks interface and wire the generate button.

    NOTE(review): the original indentation was lost; the nesting below is
    inferred from the component order and the inline comments.

    Returns:
        The constructed ``gr.Blocks`` app (not yet launched).
    """
    with gr.Blocks(theme=gr.themes.Soft(), title="Voice Cloning TTS Chatterbox") as demo:
        gr.Markdown("# 🗣️ Voice Cloning TTS Engine")
        gr.Markdown("""
        **A high-performance voice cloning application powered by Chatterbox TTS.**
        Optimized for long scripts with intelligent chunking, context preservation, and smooth concatenation.
        """)

        with gr.Row():
            # Configuration Column
            with gr.Column(scale=1):
                text_input = gr.Textbox(
                    label="Script",
                    placeholder="Paste your long script here. The engine automatically splits it at sentence boundaries for smooth narration...",
                    lines=10,
                    value="Welcome to the modular voice cloning application. By separating the core processing engine into its own file, we ensure cleaner code and better scalability. This tool automatically handles long texts, ensuring that your narration is smooth and continuous across multiple sentences."
                )
                ref_audio = gr.Audio(
                    label="Reference Voice (Voice to Clone)",
                    type="filepath",
                    sources=["upload", "microphone"]
                )

                with gr.Row():
                    exaggeration = gr.Slider(
                        0.1, 1.0, value=0.5, step=0.05,
                        label="Exaggeration",
                        info="Intensity of cloned voice traits. Default 0.5. Warning: >0.8 can be unstable."
                    )
                    cfg_weight = gr.Slider(
                        0.0, 1.0, value=0.5, step=0.05,
                        label="CFG/Pace",
                        info="Balance between text adherence and reference voice speed."
                    )

                with gr.Accordion("Advanced Options", open=False):
                    seed = gr.Number(
                        label="Seed",
                        value=0,
                        precision=0,
                        info="Set to 0 for a random seed each time."
                    )
                    temperature = gr.Slider(
                        0.1, 2.0, value=1.0, step=0.05,
                        label="Temperature",
                        info="Higher values increase expressiveness and randomness."
                    )

                generate_btn = gr.Button("Generate Speech", variant="primary")

            # Result Column
            with gr.Column(scale=1):
                audio_output = gr.Audio(label="Generated Speech", type="filepath")
                status_msg = gr.Textbox(label="Status", interactive=False)

                gr.Markdown("### 📖 Documentation")
                gr.Markdown("""
                ### Features
                - **Modular Engine**: The `VoiceCloningEngine` in `engine.py` handles all core processing, making the app easier to maintain.
                - **Intelligent Chunking**: Scripts are automatically split at sentence boundaries (~250 chars) for stability.
                - **Context Preservation**: Audio segments are concatenated smoothly for long-form narration.

                ### Deployment & Secrets
                - **Secrets Management**: If your app requires API keys, set them in the **Hugging Face Space Secrets** and access them via `os.getenv()`.
                - **GPU Recommended**: This app runs best on a T4 or L4 GPU Space.
                """)

        # Connect UI events. The input order must match process_tts's
        # positional parameters.
        generate_btn.click(
            fn=process_tts,
            inputs=[
                text_input,
                ref_audio,
                exaggeration,
                cfg_weight,
                temperature,
                seed
            ],
            outputs=[audio_output, status_msg]
        )

    return demo
117
 
118
if __name__ == "__main__":
    ui = create_ui()
    # Ensure server_name is set for Hugging Face compatibility
    ui.launch(server_name="0.0.0.0")