"""
Quick sanity-check: make Dia speak one sentence and write it to a WAV file.

Run inside the container: python smoke_test.py

Pass ``--device`` to force a specific torch device; otherwise CUDA is
preferred, then Apple MPS, then CPU.
"""
import argparse
from typing import Optional

import soundfile as sf
import torch

from dia.model import Dia

MODEL_ID = "nari-labs/Dia-1.6B"
PROMPT = "[S1] Hello world, this is Dia on a clean build!"
OUTPUT_PATH = "dia_hello.wav"
# Dia's reference usage writes its output at 44.1 kHz. Writing the same
# samples with a 24 kHz header makes the file play back slow and low-pitched.
SAMPLE_RATE = 44100


def select_device(requested: Optional[str] = None) -> torch.device:
    """Return the torch device to run the smoke test on.

    Args:
        requested: Device string supplied by the user (e.g. ``"cuda"`` or
            ``"cpu"``). ``None`` or an empty string means auto-detect.

    Returns:
        The requested device if one was given; otherwise CUDA when
        available, then MPS when available, then CPU.
    """
    if requested:
        return torch.device(requested)
    if torch.cuda.is_available():
        return torch.device("cuda")
    # Older torch builds have no ``torch.backends.mps`` attribute at all,
    # so probe for it before asking whether MPS is available.
    if hasattr(torch.backends, "mps") and torch.backends.mps.is_available():
        return torch.device("mps")
    return torch.device("cpu")


def main() -> None:
    """Load Dia, synthesize one sentence and save it to ``OUTPUT_PATH``."""
    parser = argparse.ArgumentParser(description="Dia model smoke test")
    parser.add_argument(
        "--device",
        type=str,
        default=None,
        help="Force device (e.g., 'cuda', 'cpu')",
    )
    args = parser.parse_args()

    device = select_device(args.device)
    print(f"Using device: {device}")

    print("Loading Dia model...")
    try:
        # NOTE(review): float16 is kept from the original script; half
        # precision on a CPU device may be slow or unsupported for some
        # ops -- confirm if this test is expected to run on CPU.
        model = Dia.from_pretrained(MODEL_ID, compute_dtype="float16", device=device)
    except Exception as e:
        # Print a short, greppable message for container logs, then let the
        # original exception (and traceback) propagate so the test fails.
        print(f"Error loading Dia model: {e}")
        raise
    print("Model loaded successfully")

    print(f"Generating audio for: {PROMPT}")
    waveform = model.generate(PROMPT)

    print("Shape:", waveform.shape, "dtype:", waveform.dtype)
    sf.write(OUTPUT_PATH, waveform, SAMPLE_RATE)
    print(f"Audio saved to {OUTPUT_PATH}")


# Guard the entry point so importing this module (for example by a test
# collector that matches ``*_test.py``) does not parse argv or load the model.
if __name__ == "__main__":
    main()