OLAVAUD committed on
Commit
8cc3ba8
·
verified ·
1 Parent(s): 81275c8

Upload app.py

Browse files
Files changed (1) hide show
  1. app.py +206 -0
app.py ADDED
@@ -0,0 +1,206 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import warnings
2
+ import os
3
+ import random
4
+ import numpy as np
5
+ import torch
6
+ import gradio as gr
7
+ from chatterbox.tts import ChatterboxTTS
8
+ from typing import Optional, Tuple
9
+ from datetime import datetime
10
+ import soundfile as sf
11
+ from pathlib import Path
12
+
13
# Silence the warning categories the model libraries emit most often.
for _ignored_category in (UserWarning, FutureWarning):
    warnings.filterwarnings("ignore", category=_ignored_category)
del _ignored_category

# Constants
DEVICE = "cpu"  # device name handed to the model loader
MAX_TEXT_LENGTH = 2000  # longest input text accepted, in characters
MAX_TEXT_SPLIT = 500  # texts longer than this are synthesised chunk by chunk
RECORDINGS_DIR = "voice_cloning_recordings"  # folder where audio files are written
# Truncated sample text shown in the input box by default.
DEFAULT_TEXT = "Once when I was six years old I saw a magnificent picture in a book..."
23
+
24
# Revised implementation (with fix)
class CPUTTS(ChatterboxTTS):
    """ChatterboxTTS subclass that forces checkpoint loading onto the CPU."""

    @classmethod
    def from_local(cls, ckpt_dir, device="cpu", **kwargs):
        """Load a model from ``ckpt_dir`` with every ``torch.load`` mapped to CPU.

        ``torch.load`` is temporarily swapped for a wrapper that overrides
        ``map_location`` with the CPU device. The original function is put
        back once the parent loader returns or raises.

        Args:
            ckpt_dir: Checkpoint directory, forwarded to the parent loader.
            device: Device argument, forwarded to the parent loader.
            **kwargs: Extra keyword arguments, forwarded to the parent loader.

        Returns:
            The model object built by the parent ``from_local``.
        """
        stock_load = torch.load

        def load_on_cpu(*load_args, **load_kwargs):
            # Override whatever map_location the caller asked for.
            load_kwargs['map_location'] = torch.device('cpu')
            return stock_load(*load_args, **load_kwargs)

        torch.load = load_on_cpu
        try:
            loaded = super().from_local(ckpt_dir, device, **kwargs)
            # Change: call to() through `_model` rather than `model`.
            if hasattr(loaded, '_model'):
                loaded._model.to('cpu')
        finally:
            # Always undo the monkeypatch, even when loading failed.
            torch.load = stock_load
        return loaded
42
+
43
class TTSService:
    """Speech-generation service behind the Gradio UI.

    The model is created lazily on the first request and cached on the
    instance; the remaining helpers validate inputs, persist audio to
    ``RECORDINGS_DIR`` and split long texts into chunks.
    """

    def __init__(self):
        # Filled in by load_model() on first use.
        self.model = None

    def load_model(self) -> ChatterboxTTS:
        """Return the cached model, loading it on the first call."""
        if self.model is None:
            with warnings.catch_warnings():
                warnings.simplefilter("ignore")
                self.model = CPUTTS.from_pretrained(DEVICE)
            # The CPU move only needs to happen once, right after loading.
            if hasattr(self.model, '_model'):
                self.model._model.to('cpu')
        return self.model

    @staticmethod
    def set_seed(seed: int) -> None:
        """Seed torch (CPU and, when available, CUDA), ``random`` and NumPy."""
        torch.manual_seed(seed)
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)

    @staticmethod
    def validate_inputs(text: str, audio_path: Optional[str]) -> Tuple[str, Optional[str]]:
        """Check the request and return ``(text, audio_path)`` unchanged.

        Raises:
            gr.Error: If the text is missing or blank, longer than
                ``MAX_TEXT_LENGTH``, or if ``audio_path`` is given but does
                not exist on disk.
        """
        # `not text` also covers None, which would otherwise crash on .strip().
        if not text or not text.strip():
            raise gr.Error("🚨 Please enter some text to synthesize")
        if len(text) > MAX_TEXT_LENGTH:
            raise gr.Error(f"📜 Text too long (max {MAX_TEXT_LENGTH} characters)")
        if audio_path and not os.path.exists(audio_path):
            raise gr.Error("🔊 Reference audio file not found")
        return text, audio_path

    @staticmethod
    def save_audio(audio: Optional[Tuple[int, np.ndarray]], prefix: str = "reference") -> Optional[str]:
        """Write ``(sample_rate, samples)`` to a WAV file in ``RECORDINGS_DIR``.

        Args:
            audio: ``(sample_rate, samples)`` pair, or None.
            prefix: Leading part of the generated file name.

        Returns:
            The path of the written file, or None when ``audio`` is None.
        """
        if audio is None:
            return None
        sr, data = audio
        os.makedirs(RECORDINGS_DIR, exist_ok=True)
        # Microseconds keep two files written within the same second from
        # overwriting each other.
        stamp = datetime.now().strftime('%Y%m%d_%H%M%S_%f')
        filename = f"{RECORDINGS_DIR}/{prefix}_{stamp}.wav"
        sf.write(filename, data, sr)
        return filename

    @staticmethod
    def split_long_text(text: str, max_length: int = MAX_TEXT_SPLIT) -> list[str]:
        """Split ``text`` at periods into chunks of about ``max_length`` chars.

        Sentences are packed greedily; a chunk is closed as soon as adding the
        next sentence would exceed ``max_length``. The text itself is
        preserved: no period is invented for an unterminated last sentence and
        none is doubled after a final period. A single sentence longer than
        ``max_length`` is emitted as its own (oversized) chunk.

        Returns:
            The non-blank chunks, each stripped of surrounding whitespace.
        """
        fragments = text.split('.')
        # Every fragment except the last one was followed by a '.' in the input.
        sentences = [fragment + '.' for fragment in fragments[:-1]]
        if fragments[-1].strip():
            sentences.append(fragments[-1])

        chunks = []
        current = ""
        for sentence in sentences:
            if current and len(current) + len(sentence) > max_length:
                chunks.append(current)
                current = sentence
            else:
                current += sentence
        if current:
            chunks.append(current)
        return [chunk.strip() for chunk in chunks if chunk.strip()]

    def generate_speech(
        self,
        text: str,
        audio_prompt: Optional[Tuple[int, np.ndarray]],
        exaggeration: float,
        temperature: float,
        seed_num: int,
        cfg_weight: float
    ) -> Tuple[int, np.ndarray]:
        """Synthesise ``text`` and return ``(sample_rate, samples)``.

        The optional reference audio is saved to disk and passed to the model
        by path. Texts longer than ``MAX_TEXT_SPLIT`` are generated chunk by
        chunk and concatenated. The result is also saved with the "output"
        prefix.

        Args:
            text: Text to synthesise.
            audio_prompt: Optional ``(sample_rate, samples)`` reference audio.
            exaggeration: Forwarded to ``model.generate``.
            temperature: Forwarded to ``model.generate``.
            seed_num: Random seed; 0 (or None) leaves the generators unseeded.
            cfg_weight: Forwarded to ``model.generate``.

        Raises:
            gr.Error: For invalid input (original message kept) or for any
                failure during generation ("Generation failed" message).
        """
        # Reject bad text before anything is written to disk.
        text, _ = self.validate_inputs(text, None)
        try:
            audio_prompt_path = self.save_audio(audio_prompt, "reference")
            text, audio_prompt_path = self.validate_inputs(text, audio_prompt_path)

            # Falsy covers both the "0 = random" convention and an empty field.
            if seed_num:
                self.set_seed(int(seed_num))

            model = self.load_model()

            if len(text) > MAX_TEXT_SPLIT:
                text_chunks = self.split_long_text(text)
            else:
                text_chunks = [text]

            segments = []
            for chunk in text_chunks:
                wav = model.generate(
                    chunk,
                    audio_prompt_path=audio_prompt_path,
                    exaggeration=exaggeration,
                    temperature=temperature,
                    cfg_weight=cfg_weight,
                )
                # Assumes generate() returns a tensor with a leading
                # dimension of size 1 -- TODO confirm against the library.
                segments.append(wav.squeeze(0).numpy())

            final_audio = segments[0] if len(segments) == 1 else np.concatenate(segments)
            self.save_audio((model.sr, final_audio), "output")
            return model.sr, final_audio
        except gr.Error:
            # Already a user-facing message; do not wrap it a second time.
            raise
        except Exception as e:
            raise gr.Error(f"❌ Generation failed: {str(e)}") from e
147
+
148
# Build and return the Gradio Blocks UI: inputs on the left, generated audio
# on the right, and one button wired to TTSService.generate_speech.
# NOTE(review): indentation was lost in this view, so the exact nesting of the
# components inside Group / Accordion / Column cannot be confirmed from here.
+ def create_interface() -> gr.Blocks:
149
# A single service instance is captured by the click handler below.
+ tts_service = TTSService()
150
+
151
+ with gr.Blocks(title="🎤 VoiceClone - Unlimited Chatterbox", theme="soft") as demo:
152
+ gr.Markdown("# 🎤 VoiceClone - Unlimited Chatterbox 🎧")
153
+ gr.Markdown("Clone voices and generate speech with AI magic! ✨")
154
+
155
+ with gr.Row():
156
# Input column: text, reference audio and generation parameters.
+ with gr.Column(scale=1):
157
+ gr.Markdown("## ⚙️ Input Parameters")
158
+ text_input = gr.Textbox(
159
+ value=DEFAULT_TEXT,
160
+ label=f"📝 Text to synthesize (max {MAX_TEXT_LENGTH} chars)",
161
+ max_lines=10,
162
+ placeholder="Enter your text here...",
163
+ interactive=True
164
+ )
165
+ with gr.Group():
166
# type="numpy": generate_speech unpacks this value as (sample rate, samples).
+ ref_audio = gr.Audio(
167
+ sources=["upload", "microphone"],
168
+ type="numpy",
169
+ label="🎤 Reference Audio (Wav)"
170
+ )
171
+ exaggeration = gr.Slider(0.25, 2, step=0.05, value=0.5,
172
+ label="🎚️ Exaggeration (Neutral = 0.5)")
173
+ cfg_weight = gr.Slider(0.0, 1, step=0.05, value=0.5,
174
+ label="⏱️ CFG/Pace Control")
175
+ with gr.Accordion("🔧 Advanced Options", open=False):
176
# generate_speech only seeds the generators when this value is non-zero.
+ seed_num = gr.Number(value=0, label="🎲 Random seed (0 = random)", precision=0)
177
+ temp = gr.Slider(0.05, 5, step=0.05, value=0.8,
178
+ label="🌡️ Temperature (higher = more random)")
179
+ generate_btn = gr.Button("✨ Generate Speech", variant="primary")
180
+
181
# Output column: the generated audio plus usage tips.
+ with gr.Column(scale=1):
182
+ gr.Markdown("## 🔊 Output")
183
+ audio_output = gr.Audio(label="🎧 Generated Speech", interactive=False)
184
+ gr.Markdown("""
185
+ **💡 Tips:**
186
+ - Use clear reference audio under 10 seconds ⏱️
187
+ - Long texts (>500 chars) will be automatically split ✂️
188
+ - Files saved in 'voice_cloning_recordings' folder 📁
189
+ - CPU mode may be slower ⏳
190
+ """)
191
+
192
# The inputs list follows generate_speech's parameter order:
# text, audio_prompt, exaggeration, temperature, seed_num, cfg_weight.
+ generate_btn.click(
193
+ fn=tts_service.generate_speech,
194
+ inputs=[text_input, ref_audio, exaggeration, temp, seed_num, cfg_weight],
195
+ outputs=audio_output,
196
+ api_name="generate"
197
+ )
198
+
199
+ return demo
200
+
201
if __name__ == "__main__":
    # Hide every CUDA device and make CPU the default torch device.
    os.environ['CUDA_VISIBLE_DEVICES'] = '-1'
    torch.set_default_device('cpu')

    # Make sure the recordings folder exists before the first request.
    os.makedirs(RECORDINGS_DIR, exist_ok=True)

    app = create_interface()
    # Queue at most 10 pending requests, then serve on all interfaces.
    queued_app = app.queue(max_size=10)
    queued_app.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
    )