| import torch |
| from transformers import AutoProcessor, DiaForConditionalGeneration |
| import gradio as gr |
|
|
| |
# Pick the compute device: use the GPU when available, otherwise fall back to CPU.
device = "cuda" if torch.cuda.is_available() else "cpu"


# Dia 1.6B dialogue TTS checkpoint from the Hugging Face Hub.
checkpoint = "nari-labs/Dia-1.6B-0626"
# Processor handles both text tokenization (input) and audio decoding (output).
processor = AutoProcessor.from_pretrained(checkpoint)
# Load the model weights and move them to the selected device.
model = DiaForConditionalGeneration.from_pretrained(checkpoint).to(device)
|
|
def tts_dialogue(script: str):
    """Generate multi-speaker speech audio from a tagged dialogue script.

    Expects script formatted with [S1], [S2] tags for dialogue.
    Example: "[S1] Hello there! [S2] Hi, how are you?"

    Returns the first decoded audio clip of the batch, for the single
    gr.Audio output component.
    """
    inputs = processor(text=script, return_tensors="pt", padding=True).to(device)
    # Inference only: disable autograd so generate() does not build a graph,
    # cutting memory use during the long (up to 3072-token) decode.
    with torch.inference_mode():
        outputs = model.generate(
            **inputs,
            max_new_tokens=3072,
            # Classifier-free guidance plus fairly hot sampling for
            # expressive speech (Dia's recommended generation settings).
            guidance_scale=3.0,
            temperature=1.8,
            top_p=0.9,
            top_k=45,
        )
    audio_list = processor.batch_decode(outputs)
    # Single output component: return the value itself, not a 1-tuple —
    # gradio interprets a tuple return as one value per output component.
    return audio_list[0]
|
|
# Build the demo UI: one text box in, one audio player out.
script_input = gr.Textbox(
    label="Dialogue Script",
    placeholder="[S1] Hello [S2] Hi!",
)
audio_output = gr.Audio(label="Generated Audio")

iface = gr.Interface(
    fn=tts_dialogue,
    inputs=script_input,
    outputs=audio_output,
    title="📢 Dia 1.6B TTS Dialogue Demo",
    description="A demo using Dia 1.6B for expressive, multi‑speaker TTS",
)
|
|
if __name__ == "__main__":
    # Start the Gradio server only when run as a script, not on import.
    iface.launch()
|
|