"""Gradio app: Nari (Dia) text-to-speech with multi-speaker splitting and optional summarization."""

import re
import tempfile
import time
import traceback
from pathlib import Path
from typing import List, Optional, Tuple

# NOTE(review): `spaces` is kept ahead of `torch`, as in the original import
# order — presumably required by the Spaces GPU runtime; confirm before sorting.
import spaces
import gradio as gr
import numpy as np
import soundfile as sf
import torch
from dia.model import Dia
from transformers import pipeline

# Sample rate attached to every audio tuple returned to Gradio.
OUTPUT_SAMPLE_RATE = 44100

# Zero-width split point in front of every speaker tag such as [S1] or [S2].
_SPEAKER_TAG_SPLIT = re.compile(r'(?=\[S\d\])')

# Load Nari model (mandatory: the app cannot work without it, so re-raise).
print("Loading Nari model...")
try:
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
except Exception as e:
    print(f"Error loading Nari model: {e}")
    raise

# Load summarization model (optional: on failure summarization is disabled).
print("Loading summarizer model...")
try:
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
except Exception as e:
    print(f"Error loading summarizer: {e}")
    summarizer = None


def split_by_speaker(text: str) -> List[str]:
    """Split text into segments by speaker labels like [S1], [S2], etc.

    Each returned segment is stripped and starts at a speaker tag (except a
    possible untagged leading segment). Empty segments are dropped.
    """
    return [seg.strip() for seg in _SPEAKER_TAG_SPLIT.split(text) if seg.strip()]


def _maybe_summarize(text: str, apply_summary: bool) -> str:
    """Return a summary of *text* when requested and available, else *text*."""
    if not apply_summary or summarizer is None:
        return text

    print("Summarizing input text...")
    # truncation=True keeps over-long inputs from failing inside the model;
    # the tail of such inputs is simply not summarized.
    summarized = summarizer(
        text,
        max_length=150,
        min_length=30,
        do_sample=False,
        truncation=True,
    )
    if summarized and isinstance(summarized, list):
        summary = summarized[0]["summary_text"]
        # An empty summary would leave nothing to synthesize; keep the input.
        if summary and summary.strip():
            text = summary
            print(f"Summarized Text: {text}")
    return text


def _prepare_audio_prompt(
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
) -> Optional[str]:
    """Write the audio prompt to a temporary mono float WAV and return its path.

    Returns ``None`` when no prompt was supplied or the prompt is empty/silent.
    The caller owns the returned file and must delete it.

    Raises:
        gr.Error: if the samples cannot be converted or the file cannot be written.
    """
    if audio_prompt_input is None:
        return None

    sr, audio_data = audio_prompt_input
    # np.any (rather than max() != 0) so a prompt whose samples are all
    # non-positive is not mistaken for silence.
    if audio_data is None or audio_data.size == 0 or not np.any(audio_data):
        return None

    if np.issubdtype(audio_data.dtype, np.integer):
        # Scale integer PCM by the dtype's maximum value.
        max_val = np.iinfo(audio_data.dtype).max
        audio_data = audio_data.astype(np.float32) / max_val
    elif not np.issubdtype(audio_data.dtype, np.floating):
        try:
            audio_data = audio_data.astype(np.float32)
        except Exception as conv_e:
            raise gr.Error(
                f"Failed to convert audio prompt to float32: {conv_e}"
            ) from conv_e

    if audio_data.ndim > 1:
        # Down-mix by averaging over the last axis.
        audio_data = np.mean(audio_data, axis=-1)
    audio_data = np.ascontiguousarray(audio_data)

    # Reserve a unique path, then close the handle BEFORE soundfile reopens the
    # path: writing to a file that is still open is not portable.
    with tempfile.NamedTemporaryFile(mode="wb", suffix=".wav", delete=False) as f_audio:
        prompt_path = f_audio.name
    try:
        sf.write(prompt_path, audio_data, sr, subtype="FLOAT")
    except Exception as write_e:
        # The caller never sees this path, so clean it up here.
        try:
            Path(prompt_path).unlink()
        except OSError:
            pass
        raise gr.Error(f"Failed to save audio prompt: {write_e}") from write_e

    print(f"Saved temporary audio prompt: {prompt_path}")
    return prompt_path


def _apply_speed_factor(audio: np.ndarray, speed_factor: float) -> np.ndarray:
    """Change playback speed by linearly resampling *audio* to a new length."""
    original_len = len(audio)
    speed_factor = max(0.1, min(speed_factor, 5.0))
    target_len = int(original_len / speed_factor)
    if target_len == original_len or target_len <= 0:
        return audio

    x_original = np.arange(original_len)
    x_resampled = np.linspace(0, original_len - 1, target_len)
    return np.interp(x_resampled, x_original, audio).astype(np.float32)


def _float_to_int16(audio: np.ndarray) -> np.ndarray:
    """Clip float audio to [-1, 1] and scale it to int16 for Gradio."""
    return (np.clip(audio, -1.0, 1.0) * 32767).astype(np.int16)


@spaces.GPU
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
    apply_summary: bool,
):
    """
    Runs Nari inference using the globally loaded model and provided inputs.

    The text is optionally summarized, split into per-speaker segments, and
    each segment is synthesized separately; the results are concatenated,
    speed-adjusted and converted to int16.

    Args:
        text_input: Dialogue text, optionally containing [S1]/[S2]-style tags.
        audio_prompt_input: Optional ``(sample_rate, samples)`` prompt; it is
            written to a temporary WAV file whose path is passed to the model.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_tokens``.
        cfg_scale: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        cfg_filter_top_k: Forwarded to ``model.generate``.
        speed_factor: Playback speed multiplier, clamped to [0.1, 5.0].
        apply_summary: Summarize the text first (if the summarizer loaded).

    Returns:
        A ``(sample_rate, int16 samples)`` tuple for ``gr.Audio``.

    Raises:
        gr.Error: on empty text, an unusable audio prompt, no generated audio,
            or any other failure during inference.
    """
    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    temp_audio_prompt_path = None
    try:
        # Optionally summarize text
        text_input = _maybe_summarize(text_input, apply_summary)

        # Process Audio Prompt
        temp_audio_prompt_path = _prepare_audio_prompt(audio_prompt_input)

        # Multi-Voice Handling
        text_segments = split_by_speaker(text_input)
        print(f"Detected {len(text_segments)} speaker segments.")

        generated = []
        start_time = time.time()
        for segment in text_segments:
            with torch.inference_mode():
                segment_audio = model.generate(
                    segment,
                    max_tokens=max_new_tokens,
                    cfg_scale=cfg_scale,
                    temperature=temperature,
                    top_p=top_p,
                    cfg_filter_top_k=cfg_filter_top_k,
                    use_torch_compile=False,
                    audio_prompt=temp_audio_prompt_path,
                )
            if segment_audio is not None:
                generated.append(segment_audio)

        # Fail with a clear message instead of an opaque TypeError further down.
        if not generated:
            raise gr.Error("The model did not generate any audio.")

        output_audio_np = np.concatenate(generated)
        end_time = time.time()
        print(f"Generation completed in {end_time - start_time:.2f}s.")

        # Resample for speed adjustment, then convert float audio to int16 for Gradio
        adjusted = _apply_speed_factor(output_audio_np, speed_factor)
        return (OUTPUT_SAMPLE_RATE, _float_to_int16(adjusted))

    except gr.Error:
        # Already user-facing; do not wrap it in a second "Inference failed".
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}") from e
    finally:
        if temp_audio_prompt_path and Path(temp_audio_prompt_path).exists():
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except Exception as e:
                print(f"Warning: {e}")


# --- Build Gradio UI ---
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""

default_text = "[S1] Hello there! How are you? \n[S2] I'm great, thanks! And you? \n[S1] Doing well! (laughs)"

# Prefer the contents of ./example.txt as the initial text when it is readable
# and non-empty; otherwise keep the built-in default.
example_txt_path = Path("./example.txt")
if example_txt_path.exists():
    try:
        file_text = example_txt_path.read_text(encoding="utf-8").strip()
        if file_text:
            default_text = file_text
    except Exception as e:
        # Best effort: report the problem and fall back to the default text.
        print(f"Warning: could not read {example_txt_path}: {e}")

with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nari Text-to-Speech with Multi-Voice and Summarization")

    with gr.Row(equal_height=False):
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter multi-speaker dialogue...",
                value=default_text,
                lines=8,
            )
            audio_prompt_input = gr.Audio(
                label="Audio Prompt (Optional)",
                show_label=True,
                sources=["upload", "microphone"],
                type="numpy",
            )
            with gr.Accordion("Advanced Settings", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens",
                    minimum=860,
                    maximum=3072,
                    value=model.config.data.audio_length,
                    step=50,
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale",
                    minimum=1.0,
                    maximum=5.0,
                    value=3.0,
                    step=0.1,
                )
                temperature = gr.Slider(
                    label="Temperature",
                    minimum=1.0,
                    maximum=1.5,
                    value=1.3,
                    step=0.05,
                )
                top_p = gr.Slider(
                    label="Top P",
                    minimum=0.8,
                    maximum=1.0,
                    value=0.95,
                    step=0.01,
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K",
                    minimum=15,
                    maximum=50,
                    value=30,
                    step=1,
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor",
                    minimum=0.5,
                    maximum=1.5,
                    value=0.94,
                    step=0.02,
                )
                apply_summary = gr.Checkbox(
                    label="Summarize Input Text before Generation?",
                    value=False,
                )

            run_button = gr.Button("Generate Audio", variant="primary")

        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                autoplay=False,
            )

    run_button.click(
        fn=run_inference,
        inputs=[
            text_input,
            audio_prompt_input,
            max_new_tokens,
            cfg_scale,
            temperature,
            top_p,
            cfg_filter_top_k,
            speed_factor_slider,
            apply_summary,
        ],
        outputs=[audio_output],
        api_name="generate_audio",
    )

# --- Launch ---
if __name__ == "__main__":
    print("Launching Gradio app...")
    demo.launch()