dia2 / app.py
yvansevic's picture
Update app.py
bc2d6d0 verified
import subprocess
import sys
import os
import shutil
import site
import importlib
# Bootstrap: replace any "dia2" package found in site-packages with the
# upstream sources. This runs at import time, before `dia2` is imported
# further down in this file.

# Step 1: Delete wrong PyPI dia2 files
for path in site.getsitepackages():
    # The reported site-packages list may include directories that do not
    # exist on disk; os.listdir() would raise on those, so skip them.
    if not os.path.isdir(path):
        continue

    dia2_path = os.path.join(path, "dia2")
    if os.path.exists(dia2_path):
        print(f"Removing: {dia2_path}")
        shutil.rmtree(dia2_path)

    for entry in os.listdir(path):
        if entry.startswith("dia2") and entry.endswith((".dist-info", ".egg-info")):
            metadata_path = os.path.join(path, entry)
            # Install metadata is usually a directory, but an .egg-info can
            # also be a single file, which shutil.rmtree() refuses to delete.
            if os.path.isdir(metadata_path):
                shutil.rmtree(metadata_path)
            else:
                os.remove(metadata_path)

# Step 2: Clone real repo (always start from a fresh checkout)
repo_dir = "/tmp/dia2_src"
if os.path.exists(repo_dir):
    shutil.rmtree(repo_dir)
subprocess.run(
    ["git", "clone", "https://github.com/nari-labs/dia2.git", repo_dir],
    check=True,
)

# Step 3: Editable install of the clone. No `--no-deps` flag is passed, so pip
# also resolves whatever dependencies the project itself declares.
subprocess.run(
    [sys.executable, "-m", "pip", "install", "--quiet", "-e", repo_dir],
    check=True,
)

# Step 4: Bust cache and prepend path so the fresh checkout wins on import
sys.path.insert(0, repo_dir)
importlib.invalidate_caches()
import gradio as gr
import tempfile
import numpy as np
import soundfile as sf
import scipy.io.wavfile as wavfile
from dia2 import Dia2, GenerationConfig, SamplingConfig
# Load the model once at import time; `generate` below uses this module-level
# `model` for every request.
print("Loading Dia2-1B model...")
model = Dia2.from_repo(
    "nari-labs/Dia2-1B",
    device="cuda",
    dtype="bfloat16",
)
print("Model loaded!")
def prepare_audio(audio_path, max_seconds=10):
    """Convert a voice prompt into a mono, length-capped 16-bit WAV file.

    Args:
        audio_path: Path of the audio file to read, or ``None``.
        max_seconds: Maximum duration to keep, in seconds. Samples past this
            point are dropped from the end.

    Returns:
        Path of a newly created temporary ``.wav`` file, or ``None`` when
        ``audio_path`` is ``None``. The file is not deleted automatically;
        the caller owns it.
    """
    if audio_path is None:
        return None

    data, samplerate = sf.read(audio_path)
    if data.ndim > 1:
        # Down-mix multi-channel audio by averaging across channels.
        data = data.mean(axis=1)

    # Slicing is a no-op when the clip is already short enough.
    max_samples = int(max_seconds * samplerate)
    data = data[:max_samples]

    # Clip before scaling: samples outside [-1, 1] would otherwise wrap
    # around when cast to int16 and come out as loud clicks.
    data_int16 = (np.clip(data, -1.0, 1.0) * 32767).astype(np.int16)

    # Only the unique file name is needed. Close the handle straight away so
    # the descriptor is not leaked and the writer below can reopen the path.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as tmp:
        out_path = tmp.name
    wavfile.write(out_path, samplerate, data_int16)
    return out_path
def generate(text, speaker1_audio, speaker2_audio, cfg_scale, temperature, top_k):
    """Synthesize speech for ``text`` and return the path of the output WAV.

    Args:
        text: Script passed to ``model.generate``.
        speaker1_audio: Optional file path of a voice prompt for speaker 1.
        speaker2_audio: Optional file path of a voice prompt for speaker 2.
        cfg_scale: Value forwarded as ``GenerationConfig.cfg_scale``.
        temperature: Sampling temperature for the audio ``SamplingConfig``.
        top_k: Top-k for the audio ``SamplingConfig``; converted to ``int``.

    Returns:
        Path of a temporary ``.wav`` file that ``model.generate`` was asked
        to write. The file is left on disk for the caller to consume.
    """
    config = GenerationConfig(
        cfg_scale=cfg_scale,
        audio=SamplingConfig(temperature=temperature, top_k=int(top_k)),
        use_cuda_graph=True,
    )

    # Reserve a unique output path and close the handle explicitly instead of
    # relying on garbage collection to do it.
    with tempfile.NamedTemporaryFile(suffix=".wav", delete=False) as out:
        output_file = out.name

    kwargs = dict(config=config, output_wav=output_file, verbose=True)
    prompt_files = []
    try:
        prompts = (
            ("prefix_speaker_1", speaker1_audio),
            ("prefix_speaker_2", speaker2_audio),
        )
        for key, upload in prompts:
            if upload:
                prepared = prepare_audio(upload)
                kwargs[key] = prepared
                if prepared is not None:
                    prompt_files.append(prepared)
        model.generate(text, **kwargs)
    finally:
        # The converted prompt files are intermediates created only for this
        # call; remove them so repeated requests do not fill the temp dir.
        # NOTE(review): assumes model.generate has finished reading the prompt
        # files by the time it returns - confirm against the dia2 API.
        for prompt_path in prompt_files:
            try:
                os.remove(prompt_path)
            except OSError:
                # Best-effort cleanup; a leftover temp file must not mask the
                # result or the real error from generation.
                pass
    return output_file
# Gradio UI: the inputs are listed in the same order as the parameters of
# `generate` (text, speaker 1 prompt, speaker 2 prompt, cfg_scale,
# temperature, top_k).
demo = gr.Interface(
    fn=generate,
    inputs=[
        gr.Textbox(
            label="Script ([S1]/[S2] speaker tags)",
            placeholder="[S1] Hey, how are you?\n[S2] I'm doing great, thanks!",
            lines=5,
        ),
        gr.Audio(label="Speaker 1 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Audio(label="Speaker 2 Voice Prompt (optional, max 10s)", type="filepath"),
        gr.Slider(1.0, 5.0, value=2.0, step=0.1, label="CFG Scale"),
        gr.Slider(0.1, 2.0, value=0.8, step=0.05, label="Temperature"),
        gr.Slider(10, 100, value=50, step=5, label="Top K"),
    ],
    outputs=gr.Audio(label="Generated Audio"),
    # The title names the checkpoint this file actually loads
    # ("nari-labs/Dia2-1B"); it previously said 2B.
    title="Dia2-1B TTS",
)
demo.launch()