#!/usr/bin/env python3
"""
Generate audio samples from dataset or directly from TTS.
Decodes SNAC tokens back to audio and saves as WAV files.
"""
import os
import sys
import argparse
import json
from pathlib import Path

import torch
import numpy as np

# NOTE: the audio/TTS packages (soundfile, snac, soprano) are imported lazily
# inside main() so that decode_snac_to_audio can be imported without them.

# Patch torch.load for older models.
# WARNING: this forces weights_only=False for EVERY torch.load call in the
# process (it overrides even an explicit weights_only=True), i.e. full pickle
# deserialization. Only load checkpoints/datasets from trusted sources.
_orig_load = torch.load
torch.load = lambda *a, **kw: _orig_load(*a, **{**kw, 'weights_only': False})

# Patch JSON encoder to handle torch.dtype (transformers bug workaround):
# torch.float16 is serialized as "float16"; anything else falls through to
# the stock encoder (which raises TypeError for unsupported objects).
_orig_default = json.JSONEncoder.default


def _patched_default(self, obj):
    if isinstance(obj, torch.dtype):
        return str(obj).split('.')[-1]
    return _orig_default(self, obj)


json.JSONEncoder.default = _patched_default

# Token ids >= SNAC_BASE carry SNAC_BASE plus a position-based multiple of
# SNAC_CODEBOOK_SIZE on top of the raw SNAC code.
SNAC_BASE = 128266
SNAC_CODEBOOK_SIZE = 4096
# Tokens per SNAC frame: 1 (codebook 0) + 2 (codebook 1) + 4 (codebook 2).
TOKENS_PER_FRAME = 7
SNAC_SAMPLE_RATE = 24000  # matches the hubertsiuzdak/snac_24khz model
TTS_SAMPLE_RATE = 32000   # rate this script assumes for SopranoTTS output


def decode_snac_to_audio(snac_tokens, snac_model, device="cuda"):
    """Decode a flat sequence of SNAC tokens back to a waveform.

    Args:
        snac_tokens: Iterable of integer-like token ids (list, tensor, array).
            Ids >= SNAC_BASE have SNAC_BASE and their position-based multiple
            of SNAC_CODEBOOK_SIZE stripped; smaller ids are used as raw codes.
        snac_model: Object exposing ``decode(codes)`` that accepts a list of
            three ``(1, n)`` long tensors, one per codebook (e.g. ``snac.SNAC``).
        device: Device the code tensors are moved to before decoding.

    Returns:
        The decoded audio as a squeezed float32 numpy array.

    Raises:
        ValueError: If fewer than TOKENS_PER_FRAME tokens are supplied.

    Trailing tokens that do not fill a whole frame are ignored.
    """
    raw = torch.tensor([int(tok) for tok in snac_tokens], dtype=torch.long)
    # For ids >= SNAC_BASE, removing the position offset == (id - BASE) % 4096.
    raw = torch.where(raw >= SNAC_BASE, (raw - SNAC_BASE) % SNAC_CODEBOOK_SIZE, raw)

    num_frames = raw.numel() // TOKENS_PER_FRAME
    if num_frames == 0:
        raise ValueError(
            f"need at least {TOKENS_PER_FRAME} SNAC tokens to decode one frame, "
            f"got {raw.numel()}"
        )

    # One row per frame, laid out as [cb0, cb1, cb1, cb2, cb2, cb2, cb2].
    # NOTE(review): some SNAC token layouts (e.g. Orpheus-style) interleave a
    # frame as [cb0, cb1, cb2, cb2, cb1, cb2, cb2] instead — confirm this
    # matches the encoder that produced the dataset.
    frames = raw[: num_frames * TOKENS_PER_FRAME].view(num_frames, TOKENS_PER_FRAME)
    codes = [
        frames[:, 0:1].reshape(1, -1).contiguous().to(device),  # 1 per frame
        frames[:, 1:3].reshape(1, -1).contiguous().to(device),  # 2 per frame
        frames[:, 3:7].reshape(1, -1).contiguous().to(device),  # 4 per frame
    ]

    with torch.no_grad():
        audio = snac_model.decode(codes)
    return audio.squeeze().float().cpu().numpy()


def _load_tts(device):
    """Load the Soprano TTS model on *device* (soprano is imported lazily)."""
    print("Loading TTS model...")
    from soprano import SopranoTTS
    return SopranoTTS(backend="transformers", device=device)


def _audio_to_numpy(audio):
    """Convert TTS output to a numpy array that soundfile can write.

    Tensors are detached, cast to float32 (numpy has no bfloat16) and moved to
    CPU. Size-1 dims are squeezed because soundfile reads 2-D data as
    (frames, channels), so a ``(1, T)`` output would otherwise become T channels.
    """
    if isinstance(audio, torch.Tensor):
        audio = audio.detach().float().cpu().numpy()
    return np.squeeze(np.asarray(audio))


def _save_text(path, question, answer):
    """Write the Q/A transcript as UTF-8 so non-ASCII text works on every platform."""
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"Question: {question}\n\nAnswer: {answer}\n")


def main():
    parser = argparse.ArgumentParser(description="Generate audio samples")
    parser.add_argument("--dataset", type=str, help="Path to dataset .pt file")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to generate")
    parser.add_argument("--output-dir", type=str, default="./audio_samples", help="Output directory")
    parser.add_argument("--with-tts", action="store_true", help="Also generate question audio with TTS")
    args = parser.parse_args()

    # Imported here, like snac/soprano, but before any model loads so a
    # missing dependency still fails fast.
    import soundfile as sf

    # Negative counts mean "no samples" in both modes (slicing would otherwise
    # treat them as "all but the last N").
    num_samples = max(args.num_samples, 0)

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # Load SNAC model
    print("Loading SNAC model...")
    import snac
    snac_model = snac.SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()

    # Load TTS if needed
    tts = _load_tts(device) if args.with_tts else None

    if args.dataset:
        print(f"Loading dataset: {args.dataset}")
        # Full pickle deserialization: only use trusted dataset files.
        dataset = torch.load(args.dataset, map_location="cpu", weights_only=False)
        print(f"Loaded {len(dataset)} samples")

        for i in range(min(num_samples, len(dataset))):
            sample = dataset[i]
            print(f"\n--- Sample {i} ---")

            question = sample.get("text", f"Question {i}")
            answer = sample.get("answer", f"Answer {i}")
            print(f"Q: {question}")
            print(f"A: {answer}")

            # Decode answer audio from SNAC tokens. A missing or too-short token
            # list is reported and skipped instead of aborting the whole run.
            try:
                snac_tokens = sample["snac_tokens"]
                print(f"Decoding {len(snac_tokens)} SNAC tokens...")
                answer_audio = decode_snac_to_audio(snac_tokens, snac_model, device)
                answer_path = output_dir / f"sample_{i:03d}_answer.wav"
                sf.write(answer_path, answer_audio, SNAC_SAMPLE_RATE)
                print(f"Saved: {answer_path}")
            except Exception as e:
                print(f"Error decoding answer: {e}")

            # Generate question audio with TTS
            if tts is not None:
                try:
                    question_audio = _audio_to_numpy(tts.infer(question))
                    question_path = output_dir / f"sample_{i:03d}_question.wav"
                    sf.write(question_path, question_audio, TTS_SAMPLE_RATE)
                    print(f"Saved: {question_path}")
                except Exception as e:
                    print(f"Error generating question: {e}")

            text_path = output_dir / f"sample_{i:03d}.txt"
            _save_text(text_path, question, answer)
            print(f"Saved: {text_path}")
    else:
        # Generate fresh samples with TTS
        if tts is None:
            tts = _load_tts(device)

        samples = [
            ("What is the capital of France?", "Paris is the capital of France, known for the Eiffel Tower."),
            ("How many planets are in our solar system?", "There are eight planets in our solar system."),
            ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa in the early 16th century."),
            ("What is the speed of light?", "The speed of light is approximately 299,792 kilometers per second."),
            ("What is Python?", "Python is a popular programming language known for its simplicity."),
        ]

        for i, (question, answer) in enumerate(samples[:num_samples]):
            print(f"\n--- Sample {i} ---")
            print(f"Q: {question}")
            print(f"A: {answer}")

            try:
                q_audio = _audio_to_numpy(tts.infer(question))
                a_audio = _audio_to_numpy(tts.infer(answer))

                sf.write(output_dir / f"sample_{i:03d}_question.wav", q_audio, TTS_SAMPLE_RATE)
                sf.write(output_dir / f"sample_{i:03d}_answer.wav", a_audio, TTS_SAMPLE_RATE)
                _save_text(output_dir / f"sample_{i:03d}.txt", question, answer)

                print(f"Saved to {output_dir}/sample_{i:03d}_*.wav")
            except Exception as e:
                print(f"Error: {e}")

    print(f"\n✅ Done! Samples saved to: {output_dir}")


if __name__ == "__main__":
    main()