File size: 4,865 Bytes
844d1a1
 
98659fe
65a784c
98659fe
 
4b0bfb1
98659fe
a60e434
98659fe
561919f
98659fe
6418466
561919f
6418466
 
 
 
 
844d1a1
98659fe
844d1a1
dc4db49
65a784c
98659fe
a60e434
 
 
 
 
 
 
 
 
 
 
 
 
98659fe
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
561919f
98659fe
 
 
 
 
 
a60e434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
0529fb1
 
 
 
 
 
 
a60e434
 
 
 
 
 
 
844d1a1
4b0bfb1
a60e434
 
98659fe
a60e434
 
98659fe
a60e434
 
 
 
 
 
 
 
98659fe
a60e434
4b0bfb1
 
a60e434
 
 
 
 
 
 
98659fe
a60e434
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
9af3f5d
 
a60e434
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
import os
import re
import tempfile
import time

import gradio as gr
import numpy as np
import soundfile as sf
import torch
from pydub import AudioSegment
from TTS.api import TTS

# Security bypass and TOS agreement
# Accept Coqui's terms of service non-interactively so the model download
# does not block on a console prompt.
os.environ["COQUI_TOS_AGREED"] = "1"

# Patch torch.load for embedding loading
# Recent torch versions default to weights_only=True, which refuses the
# pickled latent dicts this app saves and loads. Force full unpickling.
# NOTE(review): this disables torch's safe-load protection for EVERY
# torch.load call in the process — only load embedding files you trust.
original_torch_load = torch.load
def patched_torch_load(*args, **kwargs):
    # Unconditionally overrides any caller-supplied weights_only value.
    kwargs['weights_only'] = False
    return original_torch_load(*args, **kwargs)
torch.load = patched_torch_load

# Initialize XTTS model
# Load the multilingual XTTS v2 checkpoint once at import time, on GPU when
# available. NOTE(review): presumably triggers a model download on first
# run — confirm disk/network expectations for deployment.
device = "cuda" if torch.cuda.is_available() else "cpu"
tts = TTS("tts_models/multilingual/multi-dataset/xtts_v2").to(device)

def extract_speaker_embedding(audio_path):
    """Compute XTTS conditioning latents for a reference audio clip.

    Args:
        audio_path: Filesystem path to the uploaded reference recording.

    Returns:
        Path to a ``.pth`` file containing the ``gpt_cond_latent`` and
        ``speaker_embedding`` tensors (moved to CPU) for later synthesis.

    Raises:
        gr.Error: Wraps any failure so Gradio surfaces it in the UI.
    """
    try:
        # Get conditioning latents using the model's built-in method.
        gpt_cond_latent, speaker_embedding = tts.synthesizer.tts_model.get_conditioning_latents(
            audio_path=[audio_path]
        )

        # Save to a unique temp file: the previous fixed filename
        # ("speaker_embedding.pth") was silently clobbered whenever two
        # users hit the shared Gradio app at the same time.
        fd, embedding_path = tempfile.mkstemp(suffix=".pth", prefix="speaker_embedding_")
        os.close(fd)  # torch.save reopens the path itself
        torch.save(
            {
                "gpt_cond_latent": gpt_cond_latent.cpu(),
                "speaker_embedding": speaker_embedding.cpu(),
            },
            embedding_path,
        )
        return embedding_path
    except Exception as e:
        raise gr.Error(f"Error extracting embedding: {str(e)}")

def split_text(text, max_length=182):
    """Split *text* into whitespace-preserving chunks for XTTS synthesis.

    Each chunk is at most ~``max_length`` characters (a single word longer
    than the limit is kept intact in its own chunk) and is terminated with
    '.', '!' or '?' so the TTS model treats it as a full sentence.

    Fixes over the original:
      * empty / whitespace-only input returns [] instead of a spurious
        "." chunk that would be sent to the synthesizer;
      * an over-long first word no longer flushes an empty "" chunk.

    Args:
        text: Input text (may be empty).
        max_length: Soft per-chunk character budget.

    Returns:
        List of non-empty, punctuation-terminated chunk strings.
    """
    chunks = []
    current = []
    current_len = 0

    # Split on whitespace but keep the separators so chunk lengths are
    # measured on the text exactly as written.
    for word in re.split(r'(\s+)', text):
        if current_len + len(word) > max_length and current:
            piece = "".join(current).strip()
            if piece:  # skip whitespace-only accumulations
                chunks.append(piece)
            current = []
            current_len = 0
        current.append(word)
        current_len += len(word)

    if current:
        piece = "".join(current).strip()
        if piece:
            chunks.append(piece)

    # Ensure every chunk ends with sentence punctuation.
    return [c if c.endswith(('.', '!', '?')) else c + '.' for c in chunks]

def synthesize_speech(text, embedding_path, language="ru"):
    """Synthesize speech from *text* using previously saved speaker latents.

    Args:
        text: Text to speak; split into <=182-character chunks first.
        embedding_path: ``.pth`` file produced by extract_speaker_embedding.
        language: XTTS language code (default "ru", matching the UI).

    Returns:
        Path to the generated 24 kHz WAV file.

    Raises:
        gr.Error: On empty/missing inputs or any synthesis failure.
    """
    try:
        # Fail early with clear messages instead of crashing later in
        # np.concatenate([]) with an obscure error.
        if not text or not text.strip():
            raise ValueError("Text input is empty")
        if not embedding_path:
            raise ValueError("No embedding file provided")

        # Load latents (torch.load is patched at module level to permit
        # full pickle loading of this trusted, self-produced file).
        embeddings = torch.load(embedding_path)
        gpt_cond_latent = embeddings["gpt_cond_latent"].to(device)
        speaker_embedding = embeddings["speaker_embedding"].to(device)

        # Synthesize each chunk separately to respect XTTS input limits.
        audio_chunks = []
        for chunk in split_text(text):
            out = tts.synthesizer.tts_model.inference(
                chunk,
                language,
                gpt_cond_latent,
                speaker_embedding,
                temperature=0.7,
                length_penalty=1.0,
                repetition_penalty=2.0,
            )
            # inference() may return a torch tensor or a numpy array.
            wav = out["wav"].squeeze()
            if isinstance(wav, torch.Tensor):
                wav = wav.cpu().numpy()
            audio_chunks.append(wav)

        # Concatenate and write to a unique temp file: the previous fixed
        # "output.wav" was clobbered by concurrent requests.
        full_audio = np.concatenate(audio_chunks)
        fd, output_path = tempfile.mkstemp(suffix=".wav", prefix="xtts_out_")
        os.close(fd)  # sf.write reopens the path itself
        sf.write(output_path, full_audio, 24000)  # XTTS v2 outputs 24 kHz
        return output_path
    except Exception as e:
        raise gr.Error(f"Error generating speech: {str(e)}")

# Gradio Interface
# Two-tab UI: tab 1 turns a reference recording into a downloadable
# embedding file; tab 2 takes that file plus text and returns audio.
with gr.Blocks(theme=gr.themes.Soft()) as demo:
    gr.Markdown("# 🐸 XTTS v2 Voice Cloning Demo")
    
    with gr.Tab("🔊 Voice Embedding Creation"):
        gr.Markdown("Upload a short Russian audio sample (3-10 seconds)")
        with gr.Row():
            # type="filepath" hands the handler a path on disk, matching
            # extract_speaker_embedding's expected input.
            audio_input = gr.Audio(
                sources=["upload", "microphone"],
                type="filepath",
                label="Input Audio",
                waveform_options={"sample_rate": 24000}
            )
            embedding_output = gr.File(label="Saved Embedding")
        extract_btn = gr.Button("Create Voice Embedding", variant="primary")
    
    with gr.Tab("📢 Speech Generation"):
        gr.Markdown("Upload embedding and enter Russian text")
        with gr.Row():
            text_input = gr.Textbox(
                label="Text Input",
                placeholder="Enter text to synthesize...",
                lines=4,
                max_lines=10
            )
            embedding_input = gr.File(label="Upload Embedding File")
        with gr.Row():
            # 24000 matches the sample rate the synthesis path writes.
            audio_output = gr.Audio(
                label="Generated Speech",
                autoplay=True,
                waveform_options={"sample_rate": 24000}
            )
        synth_btn = gr.Button("Generate Speech", variant="primary")

    # Event handlers
    # Wire each button to its handler; Gradio passes component values as
    # positional arguments in the listed order.
    extract_btn.click(
        extract_speaker_embedding,
        inputs=audio_input,
        outputs=embedding_output
    )
    
    synth_btn.click(
        synthesize_speech,
        inputs=[text_input, embedding_input],
        outputs=audio_output
    )

if __name__ == "__main__":
    # Serve on all interfaces (LAN-reachable) at the standard Gradio port;
    # no public share link; surface handler exceptions in the browser.
    demo.launch(
        server_name="0.0.0.0",
        server_port=7860,
        share=False,
        show_error=True
    )