omini-model / datasets /generate_samples.py
marcos
feat: Add full fine-tuning (no LoRA) and dataset generation tools
cbe0918
Raw
History Blame Contribute Delete
7.13 kB
#!/usr/bin/env python3
"""
Generate audio samples from dataset or directly from TTS.
Decodes SNAC tokens back to audio and saves as WAV files.
"""
import os
import sys
import argparse
import json
from pathlib import Path
import torch
import numpy as np
import soundfile as sf
# Patch torch.load so checkpoints/datasets pickled by older code still load.
# Newer PyTorch releases default ``weights_only`` to True, which refuses to
# unpickle arbitrary Python objects; this restores the old default for the
# whole process.
# SECURITY: weights_only=False unpickles arbitrary objects and can execute
# code -- only load files from trusted sources.
_orig_load = torch.load


def _load_with_legacy_default(*args, **kwargs):
    """Call the original ``torch.load`` with ``weights_only`` defaulting to False.

    Unlike a blanket override, an explicit ``weights_only=...`` passed by the
    caller is forwarded unchanged.
    """
    kwargs.setdefault("weights_only", False)
    return _orig_load(*args, **kwargs)


torch.load = _load_with_legacy_default
# Teach the stdlib JSON encoder to serialize torch.dtype values by bare name
# (transformers bug workaround: some code paths dump dtype objects to JSON).
_orig_default = json.JSONEncoder.default


def _patched_default(self, obj):
    """Encode a ``torch.dtype`` as its short name; defer anything else to the stock encoder."""
    if not isinstance(obj, torch.dtype):
        # Stock behavior: raises TypeError for unsupported objects.
        return _orig_default(self, obj)
    # Keep only the text after the last dot of the dtype's string form.
    return str(obj).rpartition(".")[2]


json.JSONEncoder.default = _patched_default
def decode_snac_to_audio(snac_tokens, snac_model, device="cuda"):
    """Decode a flat sequence of SNAC tokens back into audio.

    Tokens at or above ``SNAC_BASE`` are treated as vocabulary IDs that carry
    a position-based offset of ``k * 4096``. The offset is stripped so only
    the codebook index (0..4095) remains. Tokens below ``SNAC_BASE`` are
    assumed to already be raw codebook indices.

    The de-offset stream is grouped into frames of 7 tokens. Each frame is
    split across the three codebooks as ``[c0, c1, c1, c2, c2, c2, c2]``.
    Trailing tokens that do not complete a frame are ignored.

    NOTE(review): this layout must match whatever encoded the dataset. Some
    SNAC-token LMs (e.g. Orpheus-style) interleave as
    ``[c0, c1, c2, c2, c1, c2, c2]`` instead -- confirm against the generator.

    Args:
        snac_tokens: Iterable of integer-like token IDs (ints or 0-d tensors).
        snac_model: Object exposing ``decode(list_of_code_tensors)``, e.g. a
            loaded ``snac.SNAC`` model.
        device: Device the code tensors are created on before decoding.

    Returns:
        The decoded audio, squeezed and converted to a host NumPy array.

    Raises:
        ValueError: If fewer than 7 tokens are given (no complete frame), or
            a token does not map to a valid codebook index.
    """
    SNAC_BASE = 128266
    CODEBOOK_SIZE = 4096
    FRAME_LEN = 7  # 1 + 2 + 4 tokens per frame

    codes_flat = []
    for tok in snac_tokens:
        tok = int(tok)
        # (tok - SNAC_BASE) == position_offset * 4096 + code, so the modulo
        # strips the offset and leaves the raw codebook index.
        codes_flat.append((tok - SNAC_BASE) % CODEBOOK_SIZE if tok >= SNAC_BASE else tok)

    num_frames = len(codes_flat) // FRAME_LEN
    if num_frames == 0:
        raise ValueError(
            f"need at least {FRAME_LEN} SNAC tokens (one frame), got {len(codes_flat)}"
        )
    codes_flat = codes_flat[: num_frames * FRAME_LEN]

    # Pass-through tokens are not range-checked by the offset removal above;
    # fail here with a clear message instead of an opaque embedding error.
    bad = next((c for c in codes_flat if not 0 <= c < CODEBOOK_SIZE), None)
    if bad is not None:
        raise ValueError(f"SNAC code {bad} is outside [0, {CODEBOOK_SIZE})")

    frames = [codes_flat[i:i + FRAME_LEN] for i in range(0, len(codes_flat), FRAME_LEN)]
    codebooks = [
        [frame[0] for frame in frames],                   # 1 token from codebook 0
        [c for frame in frames for c in frame[1:3]],      # 2 tokens from codebook 1
        [c for frame in frames for c in frame[3:7]],      # 4 tokens from codebook 2
    ]
    codes_tensors = [
        torch.tensor(cb, dtype=torch.long, device=device).unsqueeze(0) for cb in codebooks
    ]

    with torch.no_grad():
        audio = snac_model.decode(codes_tensors)
    return audio.squeeze().cpu().numpy()
# Sample rates used when writing WAV files.
_SNAC_SAMPLE_RATE = 24000  # matches the "snac_24khz" checkpoint loaded below
_TTS_SAMPLE_RATE = 32000  # Soprano output rate as assumed by this script -- TODO confirm

# Built-in Q&A pairs synthesized when no --dataset is given.
_DEMO_QA_PAIRS = [
    ("What is the capital of France?", "Paris is the capital of France, known for the Eiffel Tower."),
    ("How many planets are in our solar system?", "There are eight planets in our solar system."),
    ("Who painted the Mona Lisa?", "Leonardo da Vinci painted the Mona Lisa in the early 16th century."),
    ("What is the speed of light?", "The speed of light is approximately 299,792 kilometers per second."),
    ("What is Python?", "Python is a popular programming language known for its simplicity."),
]


def _to_numpy(audio):
    """Return TTS output as a host NumPy array suitable for ``soundfile.write``."""
    if torch.is_tensor(audio):
        # tts.infer may return a (possibly GPU) tensor; detach in case autograd is on.
        return audio.detach().cpu().numpy()
    return np.asarray(audio)


def _write_text(path, question, answer):
    """Write a question/answer pair to *path* as UTF-8 text."""
    # Explicit encoding: the platform default (e.g. cp1252) can fail on non-ASCII text.
    with open(path, "w", encoding="utf-8") as f:
        f.write(f"Question: {question}\n\nAnswer: {answer}\n")


def _load_snac(device):
    """Load the 24 kHz SNAC codec used to decode dataset answer tokens."""
    print("Loading SNAC model...")
    import snac
    return snac.SNAC.from_pretrained("hubertsiuzdak/snac_24khz").to(device).eval()


def _load_tts(device):
    """Load Soprano TTS; imported lazily so it is only required when TTS is used."""
    print("Loading TTS model...")
    from soprano import SopranoTTS
    return SopranoTTS(backend="transformers", device=device)


def _generate_from_dataset(dataset_path, num_samples, output_dir, snac_model, tts, device):
    """Decode answer audio for the first *num_samples* dataset entries, optionally voicing questions."""
    print(f"Loading dataset: {dataset_path}")
    # SECURITY: weights_only=False unpickles arbitrary objects and can run code;
    # only pass dataset files from trusted sources.
    dataset = torch.load(dataset_path, map_location="cpu", weights_only=False)
    print(f"Loaded {len(dataset)} samples")

    for i in range(min(num_samples, len(dataset))):
        sample = dataset[i]
        print(f"\n--- Sample {i} ---")
        # NOTE(review): samples are assumed to be dicts with "text", "answer"
        # and "snac_tokens" keys -- verify against the dataset generator.
        question = sample.get("text", f"Question {i}")
        answer = sample.get("answer", f"Answer {i}")
        print(f"Q: {question}")
        print(f"A: {answer}")
        prefix = f"sample_{i:03d}"

        # Best effort per sample: a failure is reported and the remaining
        # outputs for this sample are still written.
        try:
            snac_tokens = sample["snac_tokens"]
            print(f"Decoding {len(snac_tokens)} SNAC tokens...")
            answer_audio = decode_snac_to_audio(snac_tokens, snac_model, device)
            answer_path = output_dir / f"{prefix}_answer.wav"
            sf.write(answer_path, answer_audio, _SNAC_SAMPLE_RATE)
            print(f"Saved: {answer_path}")
        except Exception as e:
            print(f"Error decoding answer: {e}")

        if tts is not None:
            try:
                question_audio = _to_numpy(tts.infer(question))
                question_path = output_dir / f"{prefix}_question.wav"
                sf.write(question_path, question_audio, _TTS_SAMPLE_RATE)
                print(f"Saved: {question_path}")
            except Exception as e:
                print(f"Error generating question: {e}")

        text_path = output_dir / f"{prefix}.txt"
        _write_text(text_path, question, answer)
        print(f"Saved: {text_path}")


def _generate_from_tts(num_samples, output_dir, tts):
    """Synthesize question and answer audio for the built-in demo Q&A pairs."""
    for i, (question, answer) in enumerate(_DEMO_QA_PAIRS[:num_samples]):
        print(f"\n--- Sample {i} ---")
        print(f"Q: {question}")
        print(f"A: {answer}")
        prefix = f"sample_{i:03d}"
        try:
            q_audio = _to_numpy(tts.infer(question))
            a_audio = _to_numpy(tts.infer(answer))
            sf.write(output_dir / f"{prefix}_question.wav", q_audio, _TTS_SAMPLE_RATE)
            sf.write(output_dir / f"{prefix}_answer.wav", a_audio, _TTS_SAMPLE_RATE)
            _write_text(output_dir / f"{prefix}.txt", question, answer)
            print(f"Saved to {output_dir}/{prefix}_*.wav")
        except Exception as e:
            print(f"Error: {e}")


def main():
    """CLI entry point: decode dataset samples or synthesize demo samples to WAV files."""
    parser = argparse.ArgumentParser(description="Generate audio samples")
    parser.add_argument("--dataset", type=str, help="Path to dataset .pt file")
    parser.add_argument("--num-samples", type=int, default=5, help="Number of samples to generate")
    parser.add_argument("--output-dir", type=str, default="./audio_samples", help="Output directory")
    parser.add_argument("--with-tts", action="store_true", help="Also generate question audio with TTS")
    args = parser.parse_args()
    # A negative count would silently slice from the end of the demo list.
    if args.num_samples < 0:
        parser.error("--num-samples must be non-negative")

    output_dir = Path(args.output_dir)
    output_dir.mkdir(parents=True, exist_ok=True)
    device = "cuda" if torch.cuda.is_available() else "cpu"
    print(f"Using device: {device}")

    # SNAC is only needed to decode dataset tokens; the TTS-only path never uses it.
    snac_model = _load_snac(device) if args.dataset else None
    # TTS is needed for question audio (--with-tts) and for the TTS-only path.
    tts = _load_tts(device) if args.with_tts or not args.dataset else None

    if args.dataset:
        _generate_from_dataset(args.dataset, args.num_samples, output_dir, snac_model, tts, device)
    else:
        _generate_from_tts(args.num_samples, output_dir, tts)
    print(f"\n✅ Done! Samples saved to: {output_dir}")


if __name__ == "__main__":
    main()