import importlib
import os
import shutil
import site
import subprocess
import sys

# Where the real Dia2 source comes from, and which checkpoint this app serves.
DIA2_REPO_URL = "https://github.com/nari-labs/dia2.git"
repo_dir = "/tmp/dia2_src"
MODEL_ID = "nari-labs/Dia2-1B"
MODEL_NAME = MODEL_ID.split("/")[-1]  # "Dia2-1B"; keeps log and UI text in sync


def _remove_path(path):
    """Delete *path*, whether it is a directory tree or a single file."""
    if os.path.isdir(path) and not os.path.islink(path):
        shutil.rmtree(path)
    else:
        os.remove(path)


def _purge_pypi_dia2():
    """Step 1: delete the unrelated PyPI ``dia2`` package and its metadata.

    Every site-packages directory is scanned for a ``dia2`` package directory
    and for ``dia2`` ``.dist-info`` / ``.egg-info`` metadata entries.
    """
    for site_dir in site.getsitepackages():
        # getsitepackages() may list directories that don't exist on this
        # system; os.listdir() would raise FileNotFoundError on them.
        if not os.path.isdir(site_dir):
            continue
        dia2_path = os.path.join(site_dir, "dia2")
        if os.path.exists(dia2_path):
            print(f"Removing: {dia2_path}")
            shutil.rmtree(dia2_path)
        for entry in os.listdir(site_dir):
            # Match "dia2-<ver>.dist-info", "dia2.egg-info", "dia2-<ver>-pyX.egg-info",
            # but not other distributions whose name merely starts with "dia2".
            if entry.startswith(("dia2-", "dia2.")) and entry.endswith(
                (".dist-info", ".egg-info")
            ):
                # Legacy .egg-info metadata can be a plain file, not a directory.
                _remove_path(os.path.join(site_dir, entry))


def _install_dia2_from_source():
    """Steps 2-4: clone nari-labs/dia2, pip-install it editable, make it importable."""
    if os.path.exists(repo_dir):
        shutil.rmtree(repo_dir)
    subprocess.run(["git", "clone", DIA2_REPO_URL, repo_dir], check=True)
    # Installs the package together with the dependencies declared in its
    # pyproject.toml (no --no-deps flag is passed).
    subprocess.run(
        [sys.executable, "-m", "pip", "install", "--quiet", "-e", repo_dir],
        check=True,
    )
    # .pth hooks written by an editable install are only processed at
    # interpreter start-up, so put the checkout on sys.path for this process.
    sys.path.insert(0, repo_dir)
    importlib.invalidate_caches()


_purge_pypi_dia2()
_install_dia2_from_source()

# These imports must come after the install above (dia2 in particular).
import gradio as gr
import tempfile
import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile
from dia2 import Dia2, GenerationConfig, SamplingConfig

print(f"Loading {MODEL_NAME} model...")
model = Dia2.from_repo(MODEL_ID, device="cuda", dtype="bfloat16")
print("Model loaded!")


def _new_temp_wav():
    """Create an empty ``.wav`` temp file (kept after closing) and return its path."""
    fd, path = tempfile.mkstemp(suffix=".wav")
    os.close(fd)  # only the path is needed; writers reopen it themselves
    return path


def prepare_audio(audio_path, max_seconds=10):
    """Convert a voice-prompt file to a mono int16 WAV of at most *max_seconds*.

    Args:
        audio_path: Path of the uploaded audio file, or ``None``.
        max_seconds: Longest prompt duration kept; the rest is truncated.

    Returns:
        Path of a new temporary WAV file (the caller owns and should delete
        it), or ``None`` when *audio_path* is ``None``.
    """
    if audio_path is None:
        return None
    data, samplerate = sf.read(audio_path)
    if data.ndim > 1:
        data = data.mean(axis=1)  # downmix all channels to mono
    max_samples = int(max_seconds * samplerate)
    data = data[:max_samples]
    # sf.read returns floats nominally in [-1, 1], but float source files are
    # not clipped; without clipping the int16 cast would wrap around.
    data_int16 = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)
    out_path = _new_temp_wav()
    wavfile.write(out_path, samplerate, data_int16)
    return out_path


def generate(text, speaker1_audio, speaker2_audio, cfg_scale, temperature, top_k):
    """Synthesise *text* with the loaded Dia2 model; return the output WAV path.

    Args:
        text: Script using ``[S1]`` / ``[S2]`` speaker tags.
        speaker1_audio: Optional file path used as the speaker-1 voice prompt.
        speaker2_audio: Optional file path used as the speaker-2 voice prompt.
        cfg_scale: Classifier-free guidance scale passed to GenerationConfig.
        temperature: Audio sampling temperature.
        top_k: Audio sampling top-k (cast to int; Gradio sliders yield floats).
    """
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        audio=SamplingConfig(temperature=temperature, top_k=int(top_k)),
        use_cuda_graph=True,
    )
    output_file = _new_temp_wav()
    kwargs = dict(config=config, output_wav=output_file, verbose=True)
    prompts = (
        ("prefix_speaker_1", speaker1_audio),
        ("prefix_speaker_2", speaker2_audio),
    )
    prompt_files = []
    try:
        for key, audio in prompts:
            if audio:
                prepared = prepare_audio(audio)
                kwargs[key] = prepared
                prompt_files.append(prepared)
        model.generate(text, **kwargs)
    finally:
        # Prepared prompt WAVs are only needed during generation; cleanup is
        # best-effort so a failed delete never masks the real result/error.
        for path in prompt_files:
            try:
                os.remove(path)
            except OSError:
                pass
    return output_file


demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Script ([S1]/[S2] speaker tags)",
            placeholder="[S1] Hey, how are you?\n[S2] I'm doing great, thanks!",
            lines=5,
        ),
        gr.Audio(label="Speaker 1 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Audio(label="Speaker 2 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Slider(1.0, 5.0, value=2.0, step=0.1, label="CFG Scale"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(10, 100, value=50, step=5, label="Top K"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    title=f"{MODEL_NAME} TTS",
)
demo.launch()