# maya1-mlx-8bit / tts.py
# Initial upload: Maya1 MLX 8-bit quantized (author: 0xhb, commit c4deb31, verified)
#!/usr/bin/env python3
# /// script
# requires-python = ">=3.11"
# dependencies = [
# "mlx>=0.31",
# "mlx-lm>=0.31",
# "torch",
# "snac",
# "soundfile",
# "numpy",
# ]
# ///
"""CLI tool to convert text files to audio using Maya1 TTS via MLX."""
import argparse
import sys
import time
from concurrent.futures import ThreadPoolExecutor, Future
from pathlib import Path
import mlx.core as mx
import numpy as np
import soundfile as sf
import torch
from mlx_lm import load
from mlx_lm.generate import cache as cache_mod, generate_step
from mlx_lm.sample_utils import make_logits_processors, make_sampler
from snac import SNAC
CODE_START_TOKEN_ID = 128257
CODE_END_TOKEN_ID = 128258
CODE_TOKEN_OFFSET = 128266
SNAC_MIN_ID = 128266
SNAC_MAX_ID = 156937
SNAC_TOKENS_PER_FRAME = 7
SOH_ID = 128259
EOH_ID = 128260
SOA_ID = 128261
BOS_ID = 128000
TEXT_EOT_ID = 128009
def build_prefix_ids(tokenizer, description: str) -> list[int]:
desc_text = f'<description="{description}">'
desc_ids = tokenizer.encode(desc_text, add_special_tokens=False)
return [SOH_ID, BOS_ID] + desc_ids
def build_suffix_ids(tokenizer, text: str, last_prefix_token: int) -> list[int]:
text_ids = tokenizer.encode(f" {text}", add_special_tokens=False)
return [last_prefix_token] + text_ids + [TEXT_EOT_ID, EOH_ID, SOA_ID, CODE_START_TOKEN_ID]
def prefill_cache(model, prefix_ids: list[int], prefill_step_size: int = 2048):
prompt_cache = cache_mod.make_prompt_cache(model)
prompt = mx.array(prefix_ids)
n = len(prefix_ids) - 1
for i in range(0, n, prefill_step_size):
chunk = prompt[i:min(i + prefill_step_size, n)]
model(chunk[None], cache=prompt_cache)
mx.eval([c.state for c in prompt_cache])
return prompt_cache
def clone_cache(prompt_cache) -> list:
cloned = []
for c in prompt_cache:
state = c.state
meta = c.meta_state
new_state = tuple(mx.array(a) for a in state)
new_c = type(c).from_state(new_state, meta)
cloned.append(new_c)
return cloned
def generate_snac_tokens(model, suffix_ids, prefix_cache, max_tokens, sampler, logits_processors):
chunk_cache = clone_cache(prefix_cache)
tokens = []
for token, _ in generate_step(
mx.array(suffix_ids), model,
max_tokens=max_tokens, sampler=sampler,
logits_processors=logits_processors, prompt_cache=chunk_cache,
):
token_id = token.item() if hasattr(token, "item") else int(token)
if token_id == CODE_END_TOKEN_ID:
break
tokens.append(token_id)
return tokens
def extract_snac_codes(token_ids):
return [t for t in token_ids if SNAC_MIN_ID <= t <= SNAC_MAX_ID]
def unpack_snac_from_7(snac_tokens):
frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME]
if frames == 0:
return [[], [], []]
l1, l2, l3 = [], [], []
for i in range(frames):
slots = snac_tokens[i * 7:(i + 1) * 7]
l1.append((slots[0] - CODE_TOKEN_OFFSET) % 4096)
l2.extend([(slots[1] - CODE_TOKEN_OFFSET) % 4096, (slots[4] - CODE_TOKEN_OFFSET) % 4096])
l3.extend([(slots[2] - CODE_TOKEN_OFFSET) % 4096, (slots[3] - CODE_TOKEN_OFFSET) % 4096,
(slots[5] - CODE_TOKEN_OFFSET) % 4096, (slots[6] - CODE_TOKEN_OFFSET) % 4096])
return [l1, l2, l3]
def decode_snac_to_audio(snac_model, snac_tokens, device):
codes = extract_snac_codes(snac_tokens)
if len(codes) < SNAC_TOKENS_PER_FRAME:
return np.array([], dtype=np.float32)
levels = unpack_snac_from_7(codes)
codes_tensor = [torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0) for level in levels]
with torch.inference_mode():
z_q = snac_model.quantizer.from_codes(codes_tensor)
audio = snac_model.decoder(z_q)[0, 0].cpu().numpy()
if len(audio) > 2048:
audio = audio[2048:]
return audio
def split_text_into_chunks(text, max_chars=200):
text = text.strip()
if not text:
return []
if len(text) <= max_chars:
return [text]
chunks = []
remaining = text
while remaining:
remaining = remaining.strip()
if not remaining:
break
if len(remaining) <= max_chars:
chunks.append(remaining)
break
split_pos = -1
limit = min(max_chars, len(remaining))
for delimiters in ['.!?', ',;:', ' ']:
for i in range(limit - 1, -1, -1):
if remaining[i] in delimiters:
split_pos = i + 1
break
if split_pos != -1:
break
if split_pos == -1:
split_pos = max_chars
chunks.append(remaining[:split_pos].strip())
remaining = remaining[split_pos:]
return [c for c in chunks if c]
def main():
parser = argparse.ArgumentParser(description="Maya1 TTS — MLX optimized for Apple Silicon")
parser.add_argument("input", help="Input text file path, or '-' for stdin")
parser.add_argument("-o", "--output", default="output.wav", help="Output WAV file (default: output.wav)")
parser.add_argument("-d", "--description",
default="Calm and clear female voice, in her 30s with an American accent, natural pacing.",
help="Voice description prompt")
parser.add_argument("-m", "--model", default=".", help="Path to MLX model directory (default: current directory)")
parser.add_argument("--max-chars", type=int, default=200, help="Max characters per chunk (default: 200)")
parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens per chunk (default: 2048)")
parser.add_argument("--temperature", type=float, default=0.4, help="Sampling temperature (default: 0.4)")
parser.add_argument("--top-p", type=float, default=0.9, help="Top-p (default: 0.9)")
parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty (default: 1.1)")
args = parser.parse_args()
if args.input == "-":
text = sys.stdin.read().strip()
else:
input_path = Path(args.input)
if not input_path.exists():
print(f"Error: '{input_path}' not found.", file=sys.stderr)
sys.exit(1)
text = input_path.read_text(encoding="utf-8").strip()
if not text:
print("Error: empty input.", file=sys.stderr)
sys.exit(1)
chunks = split_text_into_chunks(text, max_chars=args.max_chars)
print(f"Input: {len(text)} characters, {len(chunks)} chunk(s)")
print("\n[1/3] Loading Maya1 MLX model...")
t0 = time.perf_counter()
model, tokenizer = load(args.model)
print(f" Model loaded in {time.perf_counter() - t0:.1f}s")
print("[2/3] Loading SNAC audio decoder...")
snac_device = "mps" if torch.backends.mps.is_available() else "cpu"
snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(snac_device)
print(f" SNAC decoder on {snac_device}")
print("[3/3] Prefilling description cache...")
prefix_ids = build_prefix_ids(tokenizer, args.description)
t0 = time.perf_counter()
prefix_cache = prefill_cache(model, prefix_ids)
print(f" Cached {len(prefix_ids)} prefix tokens in {time.perf_counter() - t0:.2f}s")
sampler = make_sampler(temp=args.temperature, top_p=args.top_p) if args.temperature > 0 else None
logits_processors = None
if args.repetition_penalty and args.repetition_penalty != 1.0:
logits_processors = make_logits_processors(repetition_penalty=args.repetition_penalty)
print(f"\nGenerating speech for {len(chunks)} chunk(s)...\n")
all_audio = [None] * len(chunks)
total_t0 = time.perf_counter()
executor = ThreadPoolExecutor(max_workers=1)
pending_decode = None
pending_idx = -1
for i, chunk in enumerate(chunks):
print(f" Chunk {i+1}/{len(chunks)}: {chunk[:80]}{'...' if len(chunk) > 80 else ''}")
suffix_ids = build_suffix_ids(tokenizer, chunk, prefix_ids[-1])
t0 = time.perf_counter()
generated_tokens = generate_snac_tokens(
model, suffix_ids, prefix_cache,
max_tokens=args.max_tokens, sampler=sampler, logits_processors=logits_processors,
)
gen_time = time.perf_counter() - t0
snac_count = sum(1 for t in generated_tokens if SNAC_MIN_ID <= t <= SNAC_MAX_ID)
print(f" Generated {len(generated_tokens)} tokens ({snac_count} SNAC) in {gen_time:.1f}s")
if pending_decode is not None:
audio = pending_decode.result()
all_audio[pending_idx] = audio
if len(audio) > 0:
print(f" [bg] Chunk {pending_idx+1} decoded: {len(audio)/24000:.2f}s")
pending_decode = executor.submit(decode_snac_to_audio, snac_model, generated_tokens, snac_device)
pending_idx = i
if pending_decode is not None:
audio = pending_decode.result()
all_audio[pending_idx] = audio
if len(audio) > 0:
print(f" [bg] Chunk {pending_idx+1} decoded: {len(audio)/24000:.2f}s")
executor.shutdown(wait=False)
valid_audio = [a for a in all_audio if a is not None and len(a) > 0]
if not valid_audio:
print("\nError: No audio was generated.", file=sys.stderr)
sys.exit(1)
silence = np.zeros(int(24000 * 0.3), dtype=np.float32)
combined = []
for i, audio in enumerate(valid_audio):
combined.append(audio)
if i < len(valid_audio) - 1:
combined.append(silence)
final_audio = np.concatenate(combined)
total_time = time.perf_counter() - total_t0
duration = len(final_audio) / 24000
print(f"\nTotal audio: {duration:.2f}s, generated in {total_time:.1f}s")
print(f"Real-time factor: {duration / total_time:.2f}x")
sf.write(args.output, final_audio, 24000)
print(f"Saved to {args.output}")
if __name__ == "__main__":
main()