codewithjarair committed on
Commit
6d30bec
·
verified ·
1 Parent(s): bccd3f8

Update engine.py

Browse files
Files changed (1) hide show
  1. engine.py +127 -0
engine.py CHANGED
@@ -0,0 +1,127 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import os
2
+ import random
3
+ import re
4
+ import tempfile
5
+ import torch
6
+ import torchaudio
7
+ import numpy as np
8
+ from chatterbox.tts import ChatterboxTTS
9
+
10
# Constants
MAX_CHUNK_CHARS = 250  # soft cap on characters per chunk — chosen for the Chatterbox model's token limit (see chunk_text)
DEFAULT_DEVICE = "cuda" if torch.cuda.is_available() else "cpu"  # prefer GPU when one is present
13
+
14
class VoiceCloningEngine:
    """
    A standalone engine to handle Chatterbox TTS operations including
    model management, text chunking, and audio generation.
    """

    def __init__(self, device=DEFAULT_DEVICE):
        self.device = device
        self.model = None  # lazily created by load_model()
        self.sr = 24000    # default sample rate for Chatterbox; replaced by the model's own rate after load

    def load_model(self):
        """Lazy-load the model to save memory until needed.

        Returns:
            The loaded ChatterboxTTS model.

        Raises:
            RuntimeError: if model initialization fails (original error chained).
        """
        if self.model is None:
            print(f"Initializing Chatterbox TTS on {self.device}...")
            try:
                self.model = ChatterboxTTS.from_pretrained(self.device)
                self.sr = self.model.sr
            except Exception as e:
                print(f"Failed to load model: {e}")
                # Chain the original exception so callers can inspect the real cause.
                raise RuntimeError(f"Model initialization failed: {str(e)}") from e
        return self.model

    def set_seed(self, seed: int) -> int:
        """Set seeds for reproducibility.

        A seed of 0 means "pick one at random"; the actually-used seed is
        returned so the caller can report or reuse it.
        """
        if seed == 0:
            seed = random.randint(1, 1000000)
        torch.manual_seed(seed)
        # Only touch the CUDA RNGs when CUDA actually exists.
        if torch.cuda.is_available():
            torch.cuda.manual_seed(seed)
            torch.cuda.manual_seed_all(seed)
        random.seed(seed)
        np.random.seed(seed)
        return seed

    def chunk_text(self, text):
        """
        Split long scripts into chunks at sentence boundaries.
        Optimized for the Chatterbox model's token limit.

        Returns:
            A list of non-empty chunk strings; [] for empty/blank input.
        """
        # Guard both None/empty and whitespace-only input; the original code
        # returned [""] for whitespace-only text, which would feed an empty
        # string to the model.
        if not text or not text.strip():
            return []

        # Split by sentence boundaries while keeping the punctuation.
        sentences = re.split(r'(?<=[.!?])\s+', text.strip())
        chunks = []
        current_chunk = ""

        for sentence in sentences:
            if len(current_chunk) + len(sentence) <= MAX_CHUNK_CHARS:
                current_chunk += (sentence + " ")
                continue

            if current_chunk:
                chunks.append(current_chunk.strip())

            if len(sentence) > MAX_CHUNK_CHARS:
                # A single sentence that is too long on its own: fall back to
                # splitting it on commas/whitespace and re-packing greedily.
                sub_parts = re.split(r'(?<=,)\s+|\s+', sentence)
                temp = ""
                for part in sub_parts:
                    if len(temp) + len(part) <= MAX_CHUNK_CHARS:
                        temp += (part + " ")
                    else:
                        if temp:
                            chunks.append(temp.strip())
                        temp = part + " "
                # Leftover sub-chunk continues accumulating with later sentences.
                current_chunk = temp
            else:
                current_chunk = sentence + " "

        if current_chunk.strip():
            chunks.append(current_chunk.strip())
        # Defensive: never return empty-string chunks.
        return [c for c in chunks if c]

    def generate(self, text, ref_audio, exaggeration, cfg_weight, temperature, seed, progress_callback=None):
        """
        Generate cloned audio by processing chunks and concatenating them.

        Args:
            text: script to synthesize (chunked internally).
            ref_audio: path to the reference audio used as the voice prompt.
            exaggeration, cfg_weight, temperature: passed through to the model.
            seed: RNG seed; 0 means "random" (actual seed is returned).
            progress_callback: optional callable(fraction, desc=...) — matches
                the Gradio progress signature.

        Returns:
            (output_wav_path, actual_seed)

        Raises:
            ValueError: if the script is empty or no reference audio is given.
        """
        # Validate inputs BEFORE paying the cost of loading the model.
        chunks = self.chunk_text(text)
        if not chunks:
            raise ValueError("The script is empty or invalid.")
        if ref_audio is None:
            raise ValueError("A reference audio file is required for voice cloning.")

        self.load_model()
        actual_seed = self.set_seed(int(seed))

        all_wavs = []
        total = len(chunks)

        for i, chunk in enumerate(chunks):
            if progress_callback:
                progress_callback((i / total), desc=f"Processing chunk {i+1}/{total}")

            # Generate the audio chunk.
            wav = self.model.generate(
                chunk,
                audio_prompt_path=ref_audio,
                exaggeration=exaggeration,
                temperature=temperature,
                cfg_weight=cfg_weight
            )

            # Ensure the output is a 2D tensor [1, T] so concatenation is uniform.
            if wav.dim() == 1:
                wav = wav.unsqueeze(0)
            all_wavs.append(wav.cpu())

        # Concatenate all segments along the time axis.
        final_wav = torch.cat(all_wavs, dim=-1)

        # Create the temp file, then CLOSE it before torchaudio writes to the
        # path: on Windows an open NamedTemporaryFile cannot be reopened for
        # writing by another handle.
        tmp = tempfile.NamedTemporaryFile(suffix=".wav", delete=False)
        output_path = tmp.name
        tmp.close()
        torchaudio.save(output_path, final_wav, self.sr)

        return output_path, actual_seed