File size: 4,471 Bytes
6d30bec
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import os
import random
import re
import tempfile
import torch
import torchaudio
import numpy as np
from chatterbox.tts import ChatterboxTTS

# Constants
# Max characters per TTS chunk — a character-count proxy for the model's
# token limit; chunk_text() splits scripts so no chunk exceeds this.
MAX_CHUNK_CHARS = 250
# Prefer GPU when one is visible to torch; otherwise fall back to CPU.
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"

class VoiceCloningEngine:
    """
    A standalone engine to handle Chatterbox TTS operations including
    lazy model loading, RNG seeding, sentence-aware text chunking, and
    chunked audio generation with concatenation.
    """
    def __init__(self, device=DEFAULT_DEVICE):
        # Target device string ("cuda" or "cpu"); the model itself is
        # loaded lazily in load_model() to save memory until needed.
        self.device = device
        self.model = None
        self.sr = 24000  # Fallback sample rate; overwritten by the model's own rate on load.

    def load_model(self):
        """Lazily load and cache the Chatterbox TTS model.

        Returns:
            The loaded ChatterboxTTS instance (cached after first call).

        Raises:
            RuntimeError: if model initialization fails.
        """
        if self.model is None:
            print(f"Initializing Chatterbox TTS on {self.device}...")
            try:
                self.model = ChatterboxTTS.from_pretrained(self.device)
                self.sr = self.model.sr
            except Exception as e:
                print(f"Failed to load model: {e}")
                # Chain the original exception so the root-cause traceback
                # is preserved for callers/logs.
                raise RuntimeError(f"Model initialization failed: {str(e)}") from e
        return self.model

    def set_seed(self, seed: int) -> int:
        """Seed all RNGs (torch, CUDA, random, numpy) for reproducibility.

        A seed of 0 means "pick a random seed". The seed actually used is
        returned so the caller can log or reuse it.
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        # Only seed CUDA generators when CUDA is actually available;
        # unconditional calls can warn or misbehave on CPU-only builds.
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text):
        """
        Split long scripts into chunks at sentence boundaries.
        Optimized for the Chatterbox model's token limit.

        Returns a list of non-empty chunk strings; whitespace-only or
        empty input yields [].
        """
        if not text or not text.strip():
            # Guard whitespace-only input too: without this, re.split on ""
            # yields [""] and an empty chunk leaks through to the model.
            return []

        # Split by sentence boundaries while keeping the punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if not sentence:
                continue  # skip empty fragments from the split
            if len(current_chunk) + len(sentence) <= MAX_CHUNK_CHARS:
                current_chunk += (sentence + " ")
            else:
                if current_chunk:
                    chunks.append(current_chunk.strip())

                # If a single sentence is too long, split it by commas or spaces.
                if len(sentence) > MAX_CHUNK_CHARS:
                    current_chunk = self._split_long_sentence(sentence, chunks)
                else:
                    current_chunk = sentence + " "

        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        return chunks

    def _split_long_sentence(self, sentence, chunks):
        """Split one oversized sentence on commas/spaces, appending full
        chunks to *chunks*; returns the trailing partial chunk (with a
        trailing space) for the caller to continue filling."""
        sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
        temp = ""
        for part in sub_parts:
            if len(temp) + len(part) <= MAX_CHUNK_CHARS:
                temp += (part + " ")
            else:
                if temp:
                    chunks.append(temp.strip())
                temp = part + " "
        return temp

    def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress_callback=None):
        """
        Generate cloned audio by processing chunks and concatenating them.

        Args:
            text: Script to synthesize; split into <= MAX_CHUNK_CHARS chunks.
            ref_audio: Path to the reference audio file for voice cloning.
            exaggeration, cfg_weight, temperature: Passed through to the model.
            seed: RNG seed; 0 selects a random seed.
            progress_callback: Optional callable(fraction, desc=...) for UI updates.

        Returns:
            (output_path, actual_seed): path to the saved WAV and the seed used.

        Raises:
            ValueError: if the script is empty or no reference audio is given.
            RuntimeError: if the model fails to load.
        """
        # Validate inputs BEFORE the expensive model load so bad calls fail fast.
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("The script is empty or invalid.")
        if ref_audio is None:
            raise ValueError("A reference audio file is required for voice cloning.")

        self.load_model()
        actual_seed = self.set_seed(int(seed))

        all_wavs = []
        total = len(chunks)

        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback((i / total), desc=f"Processing chunk {i+1}/{total}")

            # Generate the audio chunk.
            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight
            )

            # Ensure the output is a 2D tensor [1, T] as torchaudio.save expects.
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            all_wavs.append(wav.cpu())

        # Concatenate all segments along the time axis.
        final_wav = torch.cat(all_wavs, dim=-1)

        # Create the temp file, then CLOSE the handle before saving: on
        # Windows an open NamedTemporaryFile locks the path, so writing to
        # tmp.name inside the `with` block would fail there.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        output_path = tmp.name
        tmp.close()
        torchaudio.save(output_path, final_wav, self.sr)

        return output_path, actual_seed