| |
| """ |
| Generate audio samples from dataset or directly from TTS. |
| Decodes SNAC tokens back to audio and saves as WAV files. |
| """ |
| import os |
| import sys |
| import argparse |
| import json |
| from pathlib import Path |
|
|
| import torch |
| import numpy as np |
| import soundfile as sf |
|
|
| |
# Keep a handle on the stock loader so the wrapper below can delegate to it.
_orig_load = torch.load


def _load_allowing_pickle(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` forced to False.

    Any ``weights_only`` value supplied by the caller is overridden.

    NOTE(review): with ``weights_only=False`` the loader unpickles arbitrary
    objects, so only files from trusted sources should be loaded.
    """
    kwargs["weights_only"] = False
    return _orig_load(*args, **kwargs)


torch.load = _load_allowing_pickle
|
|
| |
# Remember the stock hook so that non-dtype values still fail the usual way.
_orig_default = json.JSONEncoder.default


def _patched_default(self, obj):
    """Serialize ``torch.dtype`` values as the text after the last dot.

    For example a dtype whose ``str()`` is ``"torch.float32"`` is emitted as
    ``"float32"``. Every other object is delegated to the original hook.
    """
    if not isinstance(obj, torch.dtype):
        return _orig_default(self, obj)
    _, _, short_name = str(obj).rpartition(".")
    return short_name


json.JSONEncoder.default = _patched_default
|
|
|
|
def decode_snac_to_audio(snac_tokens, snac_model, device="cuda"):
    """Decode a flat sequence of SNAC tokens into an audio waveform.

    Tokens are consumed in frames of 7. Within each frame, token 0 goes to
    the first code sequence, tokens 1-2 to the second, and tokens 3-6 to the
    third. Trailing tokens that do not fill a whole frame are ignored.

    Args:
        snac_tokens: Iterable of values convertible with ``int()``. Values at
            or above 128266 have that base subtracted and are reduced modulo
            4096; smaller values are used unchanged.
        snac_model: Object exposing ``decode(list_of_code_tensors)``.
        device: Device the code tensors are moved to before decoding.

    Returns:
        The decoder output, squeezed and converted to a NumPy array on CPU.

    Raises:
        ValueError: If fewer than 7 tokens are supplied, i.e. there is no
            complete frame to decode.
    """
    SNAC_BASE = 128266
    CODEBOOK_SIZE = 4096
    TOKENS_PER_FRAME = 7

    # Strip the vocabulary base and the per-position offset in one step:
    # (tok - base) % 4096 equals tok - base - ((tok - base) // 4096) * 4096.
    raw_codes = [
        (tok - SNAC_BASE) % CODEBOOK_SIZE if tok >= SNAC_BASE else tok
        for tok in map(int, snac_tokens)
    ]

    num_frames = len(raw_codes) // TOKENS_PER_FRAME
    if num_frames == 0:
        # Fail early with a clear message instead of handing zero-length
        # tensors to the decoder.
        raise ValueError(
            f"Need at least {TOKENS_PER_FRAME} SNAC tokens to decode one "
            f"frame, got {len(raw_codes)}"
        )

    # One row per frame; incomplete trailing tokens are dropped.
    frames = torch.tensor(
        raw_codes[: num_frames * TOKENS_PER_FRAME], dtype=torch.long
    ).view(num_frames, TOKENS_PER_FRAME)

    # NOTE(review): the 1/2/4 split assumes each frame was written as
    # [c0, c1, c1, c2, c2, c2, c2] — confirm against the encoder that
    # produced the dataset.
    # Row-major flattening keeps the per-frame order of the original tokens.
    codes_tensors = [
        frames[:, 0:1].reshape(1, -1).to(device),
        frames[:, 1:3].reshape(1, -1).to(device),
        frames[:, 3:7].reshape(1, -1).to(device),
    ]

    with torch.no_grad():
        audio = snac_model.decode(codes_tensors)

    return audio.squeeze().cpu().numpy()
|
|
|
|
# Sample rate passed to soundfile for decoded SNAC audio.
# NOTE(review): presumably matches the "snac_24khz" checkpoint — confirm.
_SNAC_SAMPLE_RATE = 24000
# Sample rate passed to soundfile for TTS output.
# NOTE(review): assumes SopranoTTS produces 32 kHz audio — confirm.
_TTS_SAMPLE_RATE = 32000

# Built-in question/answer pairs used when no dataset is given.
_DEMO_QA_PAIRS = (
    ("What is the capital of France?", "Paris is the capital of France, known for the Eiffel Tower."),
    ("How many planets are in our solar system?", "There are eight planets in our solar system."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa in the early 16th century."),
    ("What is the speed of light?", "The speed of light is approximately 299,792 kilometers per second."),
    ("What is Python?", "Python is a popular programming language known for its simplicity."),
)


def _load_tts(device):
    """Import and construct the Soprano TTS engine on ``device``."""
    print("Loading TTS model...")
    from soprano import SopranoTTS
    return SopranoTTS(backend="transformers", device=device)


def _as_numpy(audio):
    """Return ``audio.cpu().numpy()`` when ``audio`` has ``.cpu``, else ``audio``."""
    if hasattr(audio, 'cpu'):
        return audio.cpu().numpy()
    return audio


def _write_transcript(path, question, answer):
    """Write the question/answer text file that accompanies the audio."""
    # Explicit UTF-8 so non-ASCII text does not depend on the platform default.
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"Question: {question}\n\nAnswer: {answer}\n")


def _export_dataset_samples(dataset_path, num_samples, output_dir, snac_model, tts, device):
    """Decode the first ``num_samples`` dataset entries to WAV and text files."""
    print(f"Loading dataset: {dataset_path}")
    # NOTE(review): weights_only=False unpickles arbitrary objects; only load
    # dataset files from trusted sources.
    dataset = torch.load(dataset_path, map_location="cpu", weights_only=False)
    print(f"Loaded {len(dataset)} samples")

    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        print(f"\n--- Sample {i} ---")

        question = sample.get("text", f"Question {i}")
        answer = sample.get("answer", f"Answer {i}")
        print(f"Q: {question}")
        print(f"A: {answer}")

        snac_tokens = sample["snac_tokens"]
        print(f"Decoding {len(snac_tokens)} SNAC tokens...")

        # Best effort per sample: report the failure and keep going.
        try:
            answer_audio = decode_snac_to_audio(snac_tokens, snac_model, device)
            answer_path = output_dir / f"sample_{i:03d}_answer.wav"
            sf.write(answer_path, answer_audio, _SNAC_SAMPLE_RATE)
            print(f"Saved: {answer_path}")
        except Exception as e:
            print(f"Error decoding answer: {e}")

        if tts is not None:
            try:
                question_audio = _as_numpy(tts.infer(question))
                question_path = output_dir / f"sample_{i:03d}_question.wav"
                sf.write(question_path, question_audio, _TTS_SAMPLE_RATE)
                print(f"Saved: {question_path}")
            except Exception as e:
                print(f"Error generating question: {e}")

        text_path = output_dir / f"sample_{i:03d}.txt"
        _write_transcript(text_path, question, answer)
        print(f"Saved: {text_path}")


def _export_tts_samples(num_samples, output_dir, tts):
    """Synthesize the built-in question/answer pairs with TTS and save them."""
    for i, (question, answer) in enumerate(_DEMO_QA_PAIRS[:num_samples]):
        print(f"\n--- Sample {i} ---")
        print(f"Q: {question}")
        print(f"A: {answer}")

        # Best effort per sample: report the failure and keep going.
        try:
            q_audio = _as_numpy(tts.infer(question))
            a_audio = _as_numpy(tts.infer(answer))

            sf.write(output_dir / f"sample_{i:03d}_question.wav", q_audio, _TTS_SAMPLE_RATE)
            sf.write(output_dir / f"sample_{i:03d}_answer.wav", a_audio, _TTS_SAMPLE_RATE)

            _write_transcript(output_dir / f"sample_{i:03d}.txt", question, answer)

            print(f"Saved to {output_dir}/sample_{i:03d}_*.wav")
        except Exception as e:
            print(f"Error: {e}")


def main():
    """Command-line entry point: write audio samples and transcripts.

    With ``--dataset`` the first ``--num-samples`` entries are decoded from
    their SNAC tokens (and, with ``--with-tts``, the question text is also
    synthesized). Without ``--dataset`` a fixed set of question/answer pairs
    is synthesized with TTS. Files are written to ``--output-dir``.
    """
    parser = argparse.ArgumentParser(description="Generate audio samples")
    parser.add_argument("--dataset", type=str, help="Path to dataset .pt file")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to generate")
    parser.add_argument("--output-dir", type=str, default="./audio_samples", help="Output directory")
    parser.add_argument("--with-tts", action="store_true", help="Also generate question audio with TTS")
    args = parser.parse_args()

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)

    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # A negative count would otherwise slice the demo list from the end.
    num_samples = max(args.num_samples, 0)

    # The SNAC decoder is only needed when decoding dataset tokens.
    snac_model = None
    if args.dataset:
        print("Loading SNAC model...")
        import snac
        snac_model = snac.SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()

    # TTS is optional in dataset mode and always required otherwise.
    tts = None
    if args.with_tts or not args.dataset:
        tts = _load_tts(device)

    if args.dataset:
        _export_dataset_samples(args.dataset, num_samples, output_dir, snac_model, tts, device)
    else:
        _export_tts_samples(num_samples, output_dir, tts)

    print(f"\n✅ Done! Samples saved to: {output_dir}")
|
|
|
|
# Run the command-line interface only when this file is executed as a script,
# not when it is imported as a module.
if __name__ == "__main__":
    main()
|
|