Update src/chatterbox/mtl_tts.py
Browse files- src/chatterbox/mtl_tts.py +22 -29
src/chatterbox/mtl_tts.py
CHANGED
|
@@ -64,9 +64,7 @@ class ChatterboxMultilingualTTS:
|
|
| 64 |
ENC_COND_LEN = 6 * S3_SR
|
| 65 |
DEC_COND_LEN = 10 * S3GEN_SR
|
| 66 |
|
| 67 |
-
def __init__(self, t3
|
| 68 |
-
tokenizer: MTLTokenizer, device: torch.device,
|
| 69 |
-
conds: Conditionals = None):
|
| 70 |
self.sr = S3GEN_SR
|
| 71 |
self.t3 = t3
|
| 72 |
self.s3gen = s3gen
|
|
@@ -77,13 +75,7 @@ class ChatterboxMultilingualTTS:
|
|
| 77 |
self.watermarker = perth.PerthImplicitWatermarker()
|
| 78 |
|
| 79 |
@classmethod
|
| 80 |
-
def
|
| 81 |
-
return SUPPORTED_LANGUAGES.copy()
|
| 82 |
-
|
| 83 |
-
@classmethod
|
| 84 |
-
def from_local(cls, ckpt_dir: Path, device: torch.device) -> "ChatterboxMultilingualTTS":
|
| 85 |
-
ckpt_dir = Path(ckpt_dir)
|
| 86 |
-
|
| 87 |
ve = VoiceEncoder()
|
| 88 |
ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
|
| 89 |
ve.to(device).eval()
|
|
@@ -105,10 +97,10 @@ class ChatterboxMultilingualTTS:
|
|
| 105 |
if (builtin_voice := ckpt_dir / "conds.pt").exists():
|
| 106 |
conds = Conditionals.load(builtin_voice)
|
| 107 |
|
| 108 |
-
return cls(t3, s3gen, ve, tokenizer, device, conds
|
| 109 |
|
| 110 |
@classmethod
|
| 111 |
-
def from_pretrained(cls, device
|
| 112 |
if device is None:
|
| 113 |
device = torch.device("cpu")
|
| 114 |
elif isinstance(device, str):
|
|
@@ -122,35 +114,36 @@ class ChatterboxMultilingualTTS:
|
|
| 122 |
revision="main",
|
| 123 |
allow_patterns=[
|
| 124 |
"ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
|
| 125 |
-
"grapheme_mtl_merged_expanded_v1.json", "conds.pt"
|
| 126 |
-
"Cangjie5_TC.json"
|
| 127 |
],
|
| 128 |
token=os.getenv("HF_TOKEN"),
|
| 129 |
))
|
| 130 |
|
| 131 |
-
|
| 132 |
-
|
| 133 |
-
# Ensure all params on CPU and eval
|
| 134 |
-
model.t3.to(device).eval()
|
| 135 |
-
model.s3gen.to(device).eval()
|
| 136 |
-
model.ve.to(device).eval()
|
| 137 |
-
if model.conds:
|
| 138 |
-
for k, v in model.conds.gen.items():
|
| 139 |
-
if torch.is_tensor(v):
|
| 140 |
-
model.conds.gen[k] = v.to(device)
|
| 141 |
-
return model
|
| 142 |
|
| 143 |
@torch.no_grad()
|
| 144 |
-
def generate(self, text: str, speaker_embedding
|
| 145 |
"""
|
| 146 |
-
|
| 147 |
-
|
| 148 |
"""
|
| 149 |
-
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 150 |
token_ids = self.tokenizer.encode(text)
|
| 151 |
token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
|
| 152 |
|
| 153 |
conds = self.conds.gen if self.conds else {}
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 154 |
t3_out = self.t3(token_ids, **conds)
|
| 155 |
audio = self.s3gen(t3_out, **conds)
|
| 156 |
|
|
|
|
| 64 |
ENC_COND_LEN = 6 * S3_SR
|
| 65 |
DEC_COND_LEN = 10 * S3GEN_SR
|
| 66 |
|
| 67 |
+
def __init__(self, t3, s3gen, ve, tokenizer, device, conds=None):
|
|
|
|
|
|
|
| 68 |
self.sr = S3GEN_SR
|
| 69 |
self.t3 = t3
|
| 70 |
self.s3gen = s3gen
|
|
|
|
| 75 |
self.watermarker = perth.PerthImplicitWatermarker()
|
| 76 |
|
| 77 |
@classmethod
|
| 78 |
+
def from_local(cls, ckpt_dir, device=torch.device("cpu")):
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 79 |
ve = VoiceEncoder()
|
| 80 |
ve.load_state_dict(torch.load(ckpt_dir / "ve.pt", weights_only=True))
|
| 81 |
ve.to(device).eval()
|
|
|
|
| 97 |
if (builtin_voice := ckpt_dir / "conds.pt").exists():
|
| 98 |
conds = Conditionals.load(builtin_voice)
|
| 99 |
|
| 100 |
+
return cls(t3, s3gen, ve, tokenizer, device, conds)
|
| 101 |
|
| 102 |
@classmethod
|
| 103 |
+
def from_pretrained(cls, device=None):
|
| 104 |
if device is None:
|
| 105 |
device = torch.device("cpu")
|
| 106 |
elif isinstance(device, str):
|
|
|
|
| 114 |
revision="main",
|
| 115 |
allow_patterns=[
|
| 116 |
"ve.pt", "t3_mtl23ls_v2.safetensors", "s3gen.pt",
|
| 117 |
+
"grapheme_mtl_merged_expanded_v1.json", "conds.pt"
|
|
|
|
| 118 |
],
|
| 119 |
token=os.getenv("HF_TOKEN"),
|
| 120 |
))
|
| 121 |
|
| 122 |
+
return cls.from_local(ckpt_dir, device)
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
| 123 |
|
| 124 |
@torch.no_grad()
|
| 125 |
+
def generate(self, text: str, speaker_embedding=None, language_id=None, **kwargs):
|
| 126 |
"""
|
| 127 |
+
CPU-safe text-to-speech.
|
| 128 |
+
Accepts optional `language_id` and any other kwargs.
|
| 129 |
"""
|
| 130 |
+
# Normalize punctuation
|
| 131 |
+
text = text.strip()
|
| 132 |
+
if not text.endswith("."):
|
| 133 |
+
text += "."
|
| 134 |
+
|
| 135 |
+
# Encode text
|
| 136 |
token_ids = self.tokenizer.encode(text)
|
| 137 |
token_ids = torch.tensor(token_ids, dtype=torch.long).unsqueeze(0).to(self.device)
|
| 138 |
|
| 139 |
conds = self.conds.gen if self.conds else {}
|
| 140 |
+
|
| 141 |
+
# Include language_id in conds if provided
|
| 142 |
+
if language_id is not None:
|
| 143 |
+
conds = conds.copy()
|
| 144 |
+
conds['language_id'] = language_id
|
| 145 |
+
|
| 146 |
+
# Run through T3 and S3Gen
|
| 147 |
t3_out = self.t3(token_ids, **conds)
|
| 148 |
audio = self.s3gen(t3_out, **conds)
|
| 149 |
|