| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| |
| """CLI tool to convert text files to audio using Maya1 TTS via MLX.""" |
|
|
| import argparse |
| import sys |
| import time |
| from concurrent.futures import ThreadPoolExecutor, Future |
| from pathlib import Path |
|
|
| import mlx.core as mx |
| import numpy as np |
| import soundfile as sf |
| import torch |
| from mlx_lm import load |
| from mlx_lm.generate import cache as cache_mod, generate_step |
| from mlx_lm.sample_utils import make_logits_processors, make_sampler |
| from snac import SNAC |
|
|
# Special-token IDs used by the Maya1 prompt format.
# NOTE(review): the layout looks like a Llama-3-style vocabulary extended with
# SNAC audio-code tokens -- confirm against the model's tokenizer config.
CODE_START_TOKEN_ID = 128257  # marks the start of the audio-code stream
CODE_END_TOKEN_ID = 128258    # emitted by the model when the audio stream ends
CODE_TOKEN_OFFSET = 128266    # first SNAC code token; codebook index = (id - offset) % 4096
SNAC_MIN_ID = 128266          # inclusive lower bound of SNAC code token IDs
SNAC_MAX_ID = 156937          # inclusive upper bound of SNAC code token IDs
SNAC_TOKENS_PER_FRAME = 7     # one SNAC frame = 7 tokens (1 coarse + 2 mid + 4 fine; see unpack_snac_from_7)
SOH_ID = 128259               # start-of-header token
EOH_ID = 128260               # end-of-header token
SOA_ID = 128261               # start-of-audio token
BOS_ID = 128000               # beginning-of-sequence token
TEXT_EOT_ID = 128009          # end-of-turn marker closing the text segment
|
|
|
|
def build_prefix_ids(tokenizer, description: str) -> list[int]:
    """Build the chunk-independent prompt prefix: header/BOS markers plus the
    tokenized voice-description tag."""
    encoded_desc = tokenizer.encode(f'<description="{description}">', add_special_tokens=False)
    prefix = [SOH_ID, BOS_ID]
    prefix.extend(encoded_desc)
    return prefix
|
|
|
|
def build_suffix_ids(tokenizer, text: str, last_prefix_token: int) -> list[int]:
    """Tokenize one text chunk and wrap it with the end-of-header /
    start-of-audio markers.

    The prefix KV cache stops one token short of the description prefix
    (see prefill_cache), so the caller re-feeds the final prefix token as
    the first suffix token.
    """
    encoded_text = tokenizer.encode(f" {text}", add_special_tokens=False)
    trailer = [TEXT_EOT_ID, EOH_ID, SOA_ID, CODE_START_TOKEN_ID]
    return [last_prefix_token, *encoded_text, *trailer]
|
|
|
|
def prefill_cache(model, prefix_ids: list[int], prefill_step_size: int = 2048):
    """Run the model over all but the last prefix token and return the
    populated KV cache.

    The final prefix token is intentionally excluded: the caller re-feeds it
    as the first suffix token (see build_suffix_ids), matching the usual
    prefill/decode split.
    """
    kv_cache = cache_mod.make_prompt_cache(model)
    tokens = mx.array(prefix_ids)
    end = len(prefix_ids) - 1
    start = 0
    # Feed the prompt in bounded chunks to cap peak memory during prefill.
    while start < end:
        stop = min(start + prefill_step_size, end)
        model(tokens[start:stop][None], cache=kv_cache)
        start = stop
    # Force lazy MLX computation so the cache is materialized before cloning.
    mx.eval([c.state for c in kv_cache])
    return kv_cache
|
|
|
|
def clone_cache(prompt_cache) -> list:
    """Deep-copy a prompt cache so one chunk's generation cannot mutate the
    shared, prefilled prefix cache."""
    def _copy_entry(entry):
        # mx.array(...) makes an independent copy of each state array.
        duplicated = tuple(mx.array(arr) for arr in entry.state)
        return type(entry).from_state(duplicated, entry.meta_state)

    return [_copy_entry(entry) for entry in prompt_cache]
|
|
|
|
def generate_snac_tokens(model, suffix_ids, prefix_cache, max_tokens, sampler, logits_processors):
    """Autoregressively sample tokens for one chunk, stopping at the
    end-of-code marker.

    Generation continues from a private clone of the prefilled description
    cache, so every chunk reuses the prefix work without cross-contamination.
    """
    local_cache = clone_cache(prefix_cache)
    collected: list[int] = []
    stream = generate_step(
        mx.array(suffix_ids), model,
        max_tokens=max_tokens, sampler=sampler,
        logits_processors=logits_processors, prompt_cache=local_cache,
    )
    for token, _ in stream:
        # generate_step may yield MLX scalars or plain ints.
        tid = token.item() if hasattr(token, "item") else int(token)
        if tid == CODE_END_TOKEN_ID:
            break
        collected.append(tid)
    return collected
|
|
|
|
def extract_snac_codes(token_ids):
    """Return only the tokens that fall inside the SNAC audio-code ID range."""
    return list(filter(lambda tok: SNAC_MIN_ID <= tok <= SNAC_MAX_ID, token_ids))
|
|
|
|
def unpack_snac_from_7(snac_tokens):
    """Unpack a flat 7-token-per-frame SNAC stream into the three codebook levels.

    Each frame contributes 1 code to level 1, 2 codes to level 2, and 4 codes
    to level 3 (slot order: [l1, l2, l3, l3, l2, l3, l3]).

    Args:
        snac_tokens: Flat list of SNAC vocab token IDs; a trailing partial
            frame is silently dropped.

    Returns:
        [l1, l2, l3] as plain lists of codebook indices in [0, 4096).
    """
    frames = len(snac_tokens) // SNAC_TOKENS_PER_FRAME
    if frames == 0:
        return [[], [], []]
    # Drop any trailing partial frame.
    snac_tokens = snac_tokens[:frames * SNAC_TOKENS_PER_FRAME]

    def _decode(tok: int) -> int:
        # Map a vocab token ID back to a 12-bit SNAC codebook index.
        return (tok - CODE_TOKEN_OFFSET) % 4096

    l1, l2, l3 = [], [], []
    for i in range(frames):
        # Fix: use the named constant consistently (original mixed it with a bare 7).
        slots = snac_tokens[i * SNAC_TOKENS_PER_FRAME:(i + 1) * SNAC_TOKENS_PER_FRAME]
        l1.append(_decode(slots[0]))
        l2.extend([_decode(slots[1]), _decode(slots[4])])
        l3.extend([_decode(slots[2]), _decode(slots[3]),
                   _decode(slots[5]), _decode(slots[6])])
    return [l1, l2, l3]
|
|
|
|
def decode_snac_to_audio(snac_model, snac_tokens, device):
    """Decode a SNAC token stream to a mono float32 waveform.

    Returns an empty array when the stream holds less than one full frame.
    """
    code_ids = extract_snac_codes(snac_tokens)
    if len(code_ids) < SNAC_TOKENS_PER_FRAME:
        return np.array([], dtype=np.float32)
    level_lists = unpack_snac_from_7(code_ids)
    code_tensors = [
        torch.tensor(level, dtype=torch.long, device=device).unsqueeze(0)
        for level in level_lists
    ]
    with torch.inference_mode():
        latents = snac_model.quantizer.from_codes(code_tensors)
        waveform = snac_model.decoder(latents)[0, 0].cpu().numpy()
    # Trim the first 2048 samples -- presumably decoder warm-up/padding
    # artifacts; confirm against the SNAC model's documented behavior.
    if len(waveform) > 2048:
        waveform = waveform[2048:]
    return waveform
|
|
|
|
def split_text_into_chunks(text, max_chars=200):
    """Split text into chunks of at most max_chars, preferring natural breaks.

    Break-point preference within each chunk window: sentence enders
    ('.', '!', '?'), then clause punctuation (',', ';', ':'), then a space,
    then a hard cut at max_chars.

    Args:
        text: Input text; surrounding whitespace is stripped.
        max_chars: Maximum characters per chunk; must be >= 1.

    Returns:
        A list of non-empty chunk strings (empty list for blank input).

    Raises:
        ValueError: If max_chars < 1 (previously this caused an infinite
            loop, since a zero-length split never shrinks the remainder).
    """
    if max_chars < 1:
        raise ValueError(f"max_chars must be >= 1, got {max_chars}")
    text = text.strip()
    if not text:
        return []
    if len(text) <= max_chars:
        return [text]
    chunks = []
    remaining = text
    while remaining:
        remaining = remaining.strip()
        if not remaining:
            break
        if len(remaining) <= max_chars:
            chunks.append(remaining)
            break
        # Scan backwards from the window edge for the best available break.
        split_pos = -1
        limit = min(max_chars, len(remaining))
        for delimiters in ('.!?', ',;:', ' '):
            for i in range(limit - 1, -1, -1):
                if remaining[i] in delimiters:
                    split_pos = i + 1  # keep the delimiter with the chunk
                    break
            if split_pos != -1:
                break
        if split_pos == -1:
            split_pos = max_chars  # no delimiter at all: hard cut
        chunks.append(remaining[:split_pos].strip())
        remaining = remaining[split_pos:]
    return [c for c in chunks if c]
|
|
|
|
def main():
    """CLI entry point: read text, synthesize speech chunk-by-chunk, write one WAV.

    Pipeline overlap: while the main thread generates tokens for chunk i+1,
    a single background worker decodes chunk i's tokens to audio.
    """
    parser = argparse.ArgumentParser(description="Maya1 TTS — MLX optimized for Apple Silicon")
    parser.add_argument("input", help="Input text file path, or '-' for stdin")
    parser.add_argument("-o", "--output", default="output.wav", help="Output WAV file (default: output.wav)")
    parser.add_argument("-d", "--description",
                        default="Calm and clear female voice, in her 30s with an American accent, natural pacing.",
                        help="Voice description prompt")
    parser.add_argument("-m", "--model", default=".", help="Path to MLX model directory (default: current directory)")
    parser.add_argument("--max-chars", type=int, default=200, help="Max characters per chunk (default: 200)")
    parser.add_argument("--max-tokens", type=int, default=2048, help="Max tokens per chunk (default: 2048)")
    parser.add_argument("--temperature", type=float, default=0.4, help="Sampling temperature (default: 0.4)")
    parser.add_argument("--top-p", type=float, default=0.9, help="Top-p (default: 0.9)")
    parser.add_argument("--repetition-penalty", type=float, default=1.1, help="Repetition penalty (default: 1.1)")
    args = parser.parse_args()

    # --- Read input text from a file or stdin; reject empty input. ---
    if args.input == "-":
        text = sys.stdin.read().strip()
    else:
        input_path = Path(args.input)
        if not input_path.exists():
            print(f"Error: '{input_path}' not found.", file=sys.stderr)
            sys.exit(1)
        text = input_path.read_text(encoding="utf-8").strip()
    if not text:
        print("Error: empty input.", file=sys.stderr)
        sys.exit(1)

    chunks = split_text_into_chunks(text, max_chars=args.max_chars)
    print(f"Input: {len(text)} characters, {len(chunks)} chunk(s)")

    # --- Load the MLX language model. ---
    print("\n[1/3] Loading Maya1 MLX model...")
    t0 = time.perf_counter()
    model, tokenizer = load(args.model)
    print(f" Model loaded in {time.perf_counter() - t0:.1f}s")

    # --- Load the SNAC vocoder (PyTorch); prefer Apple MPS when available. ---
    print("[2/3] Loading SNAC audio decoder...")
    snac_device = "mps" if torch.backends.mps.is_available() else "cpu"
    snac_model = SNAC.from_pretrained("hubertsiuzdak/snac_24khz").eval().to(snac_device)
    print(f" SNAC decoder on {snac_device}")

    # --- Prefill the KV cache for the chunk-independent description prefix. ---
    print("[3/3] Prefilling description cache...")
    prefix_ids = build_prefix_ids(tokenizer, args.description)
    t0 = time.perf_counter()
    prefix_cache = prefill_cache(model, prefix_ids)
    print(f" Cached {len(prefix_ids)} prefix tokens in {time.perf_counter() - t0:.2f}s")

    # Temperature 0 disables the sampler (generate_step then uses its default).
    sampler = make_sampler(temp=args.temperature, top_p=args.top_p) if args.temperature > 0 else None
    logits_processors = None
    if args.repetition_penalty and args.repetition_penalty != 1.0:
        logits_processors = make_logits_processors(repetition_penalty=args.repetition_penalty)

    print(f"\nGenerating speech for {len(chunks)} chunk(s)...\n")
    all_audio = [None] * len(chunks)
    total_t0 = time.perf_counter()

    # Single worker so at most one decode runs behind the current generation.
    executor = ThreadPoolExecutor(max_workers=1)
    pending_decode = None
    pending_idx = -1

    for i, chunk in enumerate(chunks):
        print(f" Chunk {i+1}/{len(chunks)}: {chunk[:80]}{'...' if len(chunk) > 80 else ''}")
        suffix_ids = build_suffix_ids(tokenizer, chunk, prefix_ids[-1])
        t0 = time.perf_counter()
        generated_tokens = generate_snac_tokens(
            model, suffix_ids, prefix_cache,
            max_tokens=args.max_tokens, sampler=sampler, logits_processors=logits_processors,
        )
        gen_time = time.perf_counter() - t0
        snac_count = sum(1 for t in generated_tokens if SNAC_MIN_ID <= t <= SNAC_MAX_ID)
        print(f" Generated {len(generated_tokens)} tokens ({snac_count} SNAC) in {gen_time:.1f}s")

        # Collect the previous chunk's background decode before queueing this one.
        if pending_decode is not None:
            audio = pending_decode.result()
            all_audio[pending_idx] = audio
            if len(audio) > 0:
                print(f" [bg] Chunk {pending_idx+1} decoded: {len(audio)/24000:.2f}s")

        pending_decode = executor.submit(decode_snac_to_audio, snac_model, generated_tokens, snac_device)
        pending_idx = i

    # Drain the last in-flight decode after the loop.
    if pending_decode is not None:
        audio = pending_decode.result()
        all_audio[pending_idx] = audio
        if len(audio) > 0:
            print(f" [bg] Chunk {pending_idx+1} decoded: {len(audio)/24000:.2f}s")

    # All submitted work already completed via .result(), so no wait needed.
    executor.shutdown(wait=False)

    valid_audio = [a for a in all_audio if a is not None and len(a) > 0]
    if not valid_audio:
        print("\nError: No audio was generated.", file=sys.stderr)
        sys.exit(1)

    # Join chunks with 0.3 s of silence at the 24 kHz SNAC sample rate.
    silence = np.zeros(int(24000 * 0.3), dtype=np.float32)
    combined = []
    for i, audio in enumerate(valid_audio):
        combined.append(audio)
        if i < len(valid_audio) - 1:
            combined.append(silence)
    final_audio = np.concatenate(combined)

    total_time = time.perf_counter() - total_t0
    duration = len(final_audio) / 24000
    print(f"\nTotal audio: {duration:.2f}s, generated in {total_time:.1f}s")
    print(f"Real-time factor: {duration / total_time:.2f}x")

    sf.write(args.output, final_audio, 24000)
    print(f"Saved to {args.output}")
|
|
|
|
# Script entry point: run the CLI when executed directly (not on import).
if __name__ == "__main__":
    main()
|
|