roombox / smoke_test.py
ak36's picture
Upload folder using huggingface_hub
3e21dc5 verified
"""
Quick sanity‑check: make Dia speak one sentence and write mono WAV.
Run inside the container: python smoke_test.py
"""
import argparse
import soundfile as sf
import torch
from dia.model import Dia
# Command-line interface: one optional flag that, when given, takes
# precedence over the automatic device detection further down.
parser = argparse.ArgumentParser(description="Dia model smoke test")
parser.add_argument(
    "--device",
    default=None,
    type=str,
    help="Force device (e.g., 'cuda', 'cpu')",
)
args = parser.parse_args()
# Pick the torch device. Priority order: an explicit --device value, then
# CUDA, then Apple MPS (only when this torch build exposes that backend),
# and finally plain CPU.
if args.device:
    device_name = args.device
elif torch.cuda.is_available():
    device_name = "cuda"
elif hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
    device_name = "mps"
else:
    device_name = "cpu"

device = torch.device(device_name)
print(f"Using device: {device}")
# Load Dia model
print("Loading Dia model...")

# Half precision is only requested on CUDA. On CPU (and MPS) many kernels
# have no or poor float16 support, so use float32 there instead of failing
# or crawling during generation.
# NOTE(review): upstream Dia's demo app uses the same device -> dtype
# mapping; re-check if the pinned Dia version changes.
compute_dtype = "float16" if device.type == "cuda" else "float32"

try:
    model = Dia.from_pretrained(
        "nari-labs/Dia-1.6B",
        compute_dtype=compute_dtype,
        device=device,
    )
except Exception as e:
    # Report the failure in the log, then re-raise so the smoke test exits
    # non-zero with the full traceback.
    print(f"Error loading Dia model: {e}")
    raise
else:
    print("Model loaded successfully")
# Generate audio
# Dia's output sample rate. The model decodes through a 44.1 kHz audio codec;
# writing the file with any other rate in the header (the script previously
# used 24000) makes the clip play back at the wrong speed and pitch.
# NOTE(review): confirm against the installed Dia version's documented rate.
SAMPLE_RATE = 44100

text = "[S1] Hello world, this is Dia on a clean build!"
print(f"Generating audio for: {text}")
waveform = model.generate(text)  # expected: 1-D float numpy waveform at SAMPLE_RATE
print("Shape:", waveform.shape, "dtype:", waveform.dtype)
sf.write("dia_hello.wav", waveform, SAMPLE_RATE)
print("Audio saved to dia_hello.wav")