| import subprocess |
| import sys |
| import os |
| import shutil |
| import site |
| import importlib |
|
|
| |
def remove_installed_dia2(site_dir):
    """Delete an installed ``dia2`` package and its metadata from *site_dir*.

    Removes the ``dia2`` entry itself plus any ``dia2-*`` / ``dia2.*`` entry
    ending in ``.dist-info`` or ``.egg-info``.

    Args:
        site_dir: A site-packages directory. It may not exist.

    Returns:
        The list of paths that were removed (empty when nothing matched or
        *site_dir* is not a directory).
    """
    # site.getsitepackages() may list directories that were never created;
    # os.listdir() on those would raise.
    if not os.path.isdir(site_dir):
        return []

    removed = []
    for entry in os.listdir(site_dir):
        is_package = entry == "dia2"
        # Require a "-" or "." right after the name so metadata belonging to
        # a different distribution (e.g. "dia2_extras-1.0.dist-info") is kept.
        is_metadata = entry.startswith(("dia2-", "dia2.")) and entry.endswith(
            (".dist-info", ".egg-info")
        )
        if not (is_package or is_metadata):
            continue

        target = os.path.join(site_dir, entry)
        if is_package:
            print(f"Removing: {target}")
        # An .egg-info entry can be a plain file, which shutil.rmtree rejects.
        if os.path.isdir(target) and not os.path.islink(target):
            shutil.rmtree(target)
        else:
            os.remove(target)
        removed.append(target)
    return removed


# Purge any previously installed dia2 before installing from source.
for path in site.getsitepackages():
    remove_installed_dia2(path)
|
|
| |
# Always work from a fresh checkout: discard whatever a previous run left
# at the clone location, then clone the upstream repository into it.
repo_dir = "/tmp/dia2_src"
if os.path.exists(repo_dir):
    shutil.rmtree(repo_dir)

subprocess.run(
    [
        "git",
        "clone",
        "https://github.com/nari-labs/dia2.git",
        repo_dir,
    ],
    check=True,  # abort the script if the clone fails
)
|
|
| |
# Editable-install the checkout using the interpreter that runs this script,
# so the package lands in the same environment.
pip_install = [sys.executable, "-m", "pip", "install", "--quiet", "-e", repo_dir]
subprocess.run(pip_install, check=True)

# Put the checkout at the front of the import path and drop cached finder
# state so modules installed after interpreter start-up can be discovered.
sys.path.insert(0, repo_dir)
importlib.invalidate_caches()
|
|
| import gradio as gr |
| import tempfile |
| import numpy as np |
| import soundfile as sf |
| import scipy.io.wavfile as wavfile |
| from dia2 import Dia2, GenerationConfig, SamplingConfig |
|
|
# Load the model once at module import; generate() reuses this global for
# every request.
print("Loading Dia2-1B model...")
model = Dia2.from_repo(
    "nari-labs/Dia2-1B",
    device="cuda",
    dtype="bfloat16",
)
print("Model loaded!")
|
|
def prepare_audio(audio_path, max_seconds=10):
    """Convert a voice prompt into a short, mono, 16-bit PCM WAV file.

    Args:
        audio_path: Path of the audio file to read, or ``None``.
        max_seconds: Maximum duration to keep, counted from the start.

    Returns:
        Path of a newly created temporary ``.wav`` file, or ``None`` when
        *audio_path* is ``None``. The file is created with ``delete=False``,
        so the caller is responsible for removing it.
    """
    if audio_path is None:
        return None

    samples, samplerate = sf.read(audio_path)

    # Multi-channel input: average the channels down to mono.
    if samples.ndim > 1:
        samples = samples.mean(axis=1)

    # Keep only the leading max_seconds of audio.
    samples = samples[: int(max_seconds * samplerate)]

    # Clip before scaling: values outside [-1, 1] would otherwise wrap
    # around when cast to int16 and come out as loud clicks.
    pcm16 = (np.clip(samples, -1.0, 1.0) * 32767).astype(np.int16)

    # Reserve a unique path and close the handle right away; leaving it open
    # leaks a file descriptor for every prepared prompt.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        wav_path = tmp.name
    wavfile.write(wav_path, samplerate, pcm16)
    return wav_path
|
|
def generate(text, speaker1_audio, speaker2_audio, cfg_scale, temperature, top_k):
    """Synthesize speech for *text* and return the path of the output WAV.

    Args:
        text: Script to synthesize.
        speaker1_audio: Optional path to a voice prompt for speaker 1.
        speaker2_audio: Optional path to a voice prompt for speaker 2.
        cfg_scale: Value passed to ``GenerationConfig(cfg_scale=...)``.
        temperature: Sampling temperature for the audio ``SamplingConfig``.
        top_k: Top-k value for the audio ``SamplingConfig``; cast to ``int``.

    Returns:
        Path of a temporary ``.wav`` file that the model was asked to write.
    """
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        audio=SamplingConfig(temperature=temperature, top_k=int(top_k)),
        use_cuda_graph=True,
    )

    # Reserve the output path and close the handle explicitly instead of
    # relying on garbage collection to do it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out_handle:
        output_file = out_handle.name

    kwargs = dict(config=config, output_wav=output_file, verbose=True)

    prompts = (
        ("prefix_speaker_1", speaker1_audio),
        ("prefix_speaker_2", speaker2_audio),
    )
    prepared_files = []
    try:
        for key, prompt in prompts:
            if prompt:
                prepared = prepare_audio(prompt)
                prepared_files.append(prepared)
                kwargs[key] = prepared

        model.generate(text, **kwargs)
    finally:
        # The prepared prompts are only needed while generating; without
        # this they pile up in the temp directory, one or two per request.
        for prepared in prepared_files:
            try:
                os.remove(prepared)
            except OSError:
                # Best-effort cleanup: a file that cannot be removed must not
                # hide the generation result or the original error.
                pass

    return output_file
|
|
# Web UI: a script box, two optional voice prompts and three sampling
# controls, passed positionally to generate() in the order listed.
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Script ([S1]/[S2] speaker tags)",
            placeholder="[S1] Hey, how are you?\n[S2] I'm doing great, thanks!",
            lines=5,
        ),
        gr.Audio(label="Speaker 1 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Audio(label="Speaker 2 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Slider(1.0, 5.0, value=2.0, step=0.1, label="CFG Scale"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(10, 100, value=50, step=5, label="Top K"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    # The checkpoint loaded by this script is nari-labs/Dia2-1B, so the
    # title names that model rather than the 2B variant.
    title="Dia2-1B TTS",
)


demo.launch()