# Spaces: Build error
| import tempfile | |
| import time | |
| from pathlib import Path | |
| from typing import Optional, Tuple | |
| import spaces | |
| import gradio as gr | |
| import numpy as np | |
| import soundfile as sf | |
| import torch | |
| from dia.model import Dia | |
| from transformers import pipeline | |
# --- Model loading (runs once, at import time) ---

# The TTS model is mandatory: report the failure, then re-raise so the app
# does not come up without it.
print("Loading Nari model...")
try:
    model = Dia.from_pretrained("nari-labs/Dia-1.6B", compute_dtype="float32")
except Exception as load_err:
    print(f"Error loading Nari model: {load_err}")
    raise

# The summarizer is optional: on failure it is left as None, which
# run_inference checks before attempting to summarize.
print("Loading summarizer model...")
try:
    summarizer = pipeline("summarization", model="facebook/bart-large-cnn")
except Exception as load_err:
    print(f"Error loading summarizer: {load_err}")
    summarizer = None
def run_inference(
    text_input: str,
    audio_prompt_input: Optional[Tuple[int, np.ndarray]],
    max_new_tokens: int,
    cfg_scale: float,
    temperature: float,
    top_p: float,
    cfg_filter_top_k: int,
    speed_factor: float,
    apply_summary: bool,
):
    """Generate speech for ``text_input`` using the globally loaded Nari model.

    The text is optionally summarized, split into speaker segments with
    ``split_by_speaker``, and each segment is synthesized separately; the
    resulting clips are concatenated, speed-adjusted and converted to int16.

    Args:
        text_input: Dialogue text; must contain non-whitespace characters.
        audio_prompt_input: Optional ``(sample_rate, samples)`` pair. When it
            holds non-silent audio it is written to a temporary WAV file and
            that file is passed as ``audio_prompt`` for every segment.
        max_new_tokens: Forwarded to ``model.generate`` as ``max_tokens``.
        cfg_scale: Forwarded to ``model.generate``.
        temperature: Forwarded to ``model.generate``.
        top_p: Forwarded to ``model.generate``.
        cfg_filter_top_k: Forwarded to ``model.generate``.
        speed_factor: Playback speed multiplier, clamped to ``[0.1, 5.0]``;
            values above 1 shorten the audio, values below 1 lengthen it.
        apply_summary: Summarize the text first (only if the summarizer
            loaded successfully).

    Returns:
        ``(44100, samples)`` where ``samples`` is an int16 array, or a
        one-sample float32 silent placeholder when the model produced no
        audio for any segment.

    Raises:
        gr.Error: If the text is empty or any processing step fails.
    """
    import traceback

    if not text_input or text_input.isspace():
        raise gr.Error("Text input cannot be empty.")

    temp_audio_prompt_path = None
    output_sr = 44100
    # Fallback result, returned only when generation yields no audio at all.
    output_audio = (output_sr, np.zeros(1, dtype=np.float32))

    try:
        # Optionally summarize text.
        # NOTE(review): the summary replaces the whole input, so any speaker
        # tags the summarizer drops are lost — confirm this is intended.
        if apply_summary and summarizer is not None:
            print("Summarizing input text...")
            summarized = summarizer(
                text_input, max_length=150, min_length=30, do_sample=False
            )
            if summarized and isinstance(summarized, list):
                text_input = summarized[0]["summary_text"]
                print(f"Summarized Text: {text_input}")

        # Process audio prompt.
        prompt_path_for_generate = None
        if audio_prompt_input is not None:
            sr, audio_data = audio_prompt_input
            # np.any (rather than max() != 0) so a clip whose samples are all
            # <= 0 is not mistaken for silence.
            if audio_data is not None and audio_data.size != 0 and np.any(audio_data):
                # Reserve a path and close the handle right away: soundfile
                # reopens the file by name, which fails on platforms that
                # forbid opening a file that is still held open.
                with tempfile.NamedTemporaryFile(
                    mode="wb", suffix=".wav", delete=False
                ) as f_audio:
                    temp_audio_prompt_path = f_audio.name

                if np.issubdtype(audio_data.dtype, np.integer):
                    # Scale integer PCM by the dtype's maximum value.
                    max_val = np.iinfo(audio_data.dtype).max
                    audio_data = audio_data.astype(np.float32) / max_val
                elif not np.issubdtype(audio_data.dtype, np.floating):
                    try:
                        audio_data = audio_data.astype(np.float32)
                    except Exception as conv_e:
                        raise gr.Error(
                            f"Failed to convert audio prompt to float32: {conv_e}"
                        ) from conv_e

                if audio_data.ndim > 1:
                    # Down-mix to mono by averaging over the last axis.
                    audio_data = np.mean(audio_data, axis=-1)
                audio_data = np.ascontiguousarray(audio_data)

                try:
                    sf.write(temp_audio_prompt_path, audio_data, sr, subtype="FLOAT")
                    prompt_path_for_generate = temp_audio_prompt_path
                    print(f"Saved temporary audio prompt: {temp_audio_prompt_path}")
                except Exception as write_e:
                    raise gr.Error(
                        f"Failed to save audio prompt: {write_e}"
                    ) from write_e

        # Multi-voice handling: synthesize each speaker segment on its own.
        text_segments = split_by_speaker(text_input)
        print(f"Detected {len(text_segments)} speaker segments.")

        final_audio = []
        start_time = time.time()
        with torch.inference_mode():
            for segment in text_segments:
                if not segment.strip():
                    continue
                segment_audio = model.generate(
                    segment,
                    max_tokens=max_new_tokens,
                    cfg_scale=cfg_scale,
                    temperature=temperature,
                    top_p=top_p,
                    cfg_filter_top_k=cfg_filter_top_k,
                    use_torch_compile=False,
                    audio_prompt=prompt_path_for_generate,
                )
                if segment_audio is not None:
                    final_audio.append(segment_audio)
        end_time = time.time()
        print(f"Generation completed in {end_time - start_time:.2f}s.")

        if not final_audio:
            # Nothing was generated: surface that clearly instead of failing
            # later on len(None); the finally clause still cleans up.
            gr.Warning("Generation produced no output.")
            return output_audio

        output_audio_np = np.concatenate(final_audio)

        # Resample for speed adjustment (linear interpolation to a new length
        # while keeping the sample rate fixed).
        original_len = len(output_audio_np)
        speed_factor = max(0.1, min(speed_factor, 5.0))
        target_len = int(original_len / speed_factor)
        if target_len != original_len and target_len > 0:
            x_original = np.arange(original_len)
            x_resampled = np.linspace(0, original_len - 1, target_len)
            output_audio_np = np.interp(
                x_resampled, x_original, output_audio_np
            ).astype(np.float32)

        # Convert float audio to int16 for Gradio.
        audio_for_gradio = np.clip(output_audio_np, -1.0, 1.0)
        audio_for_gradio = (audio_for_gradio * 32767).astype(np.int16)
        output_audio = (output_sr, audio_for_gradio)
    except gr.Error:
        # Already a user-facing error; do not wrap it a second time.
        raise
    except Exception as e:
        traceback.print_exc()
        raise gr.Error(f"Inference failed: {e}") from e
    finally:
        # Best-effort removal of the temporary prompt file.
        if temp_audio_prompt_path:
            try:
                Path(temp_audio_prompt_path).unlink()
                print(f"Deleted temporary audio prompt file: {temp_audio_prompt_path}")
            except FileNotFoundError:
                pass
            except OSError as cleanup_err:
                print(f"Warning: {cleanup_err}")

    return output_audio
def split_by_speaker(text: str) -> list:
    """Split text into segments by speaker labels like [S1], [S2], etc.

    The split happens immediately before each ``[S<number>]`` tag, so every
    tag stays attached to the text that follows it. Any text preceding the
    first tag becomes its own segment. Segments are stripped of surrounding
    whitespace and empty segments are dropped.

    Args:
        text: Dialogue text, with or without speaker tags.

    Returns:
        A list of non-empty, stripped segment strings (empty list for empty
        or whitespace-only input).
    """
    import re

    # Zero-width lookahead keeps the tag in the segment it introduces;
    # ``\d+`` accepts multi-digit speaker numbers such as [S10].
    pieces = re.split(r"(?=\[S\d+\])", text)
    stripped = (piece.strip() for piece in pieces)
    return [piece for piece in stripped if piece]
# --- Build Gradio UI ---
css = """
#col-container {max-width: 90%; margin-left: auto; margin-right: auto;}
"""

# Built-in sample dialogue used as the initial textbox content.
default_text = "[S1] Hello there! How are you? \n[S2] I'm great, thanks! And you? \n[S1] Doing well! (laughs)"

# Prefer the contents of ./example.txt when that file can be read and is not
# empty; otherwise keep the built-in sample above.
example_txt_path = Path("./example.txt")
try:
    file_text = example_txt_path.read_text(encoding="utf-8").strip()
except FileNotFoundError:
    # The example file is optional; its absence is the normal case.
    file_text = ""
except (OSError, UnicodeError) as read_err:
    # Still best-effort, but say why the file was ignored.
    print(f"Warning: could not read {example_txt_path}: {read_err}")
    file_text = ""
if file_text:
    default_text = file_text
# Page layout: inputs and generation settings on the left, generated audio on
# the right; the button click is wired to run_inference.
# NOTE(review): the nesting below was reconstructed from source whose
# indentation was lost — confirm component placement against the running app.
with gr.Blocks(css=css) as demo:
    gr.Markdown("# Nari Text-to-Speech with Multi-Voice and Summarization")

    with gr.Row(equal_height=False):
        # Left column: text, optional audio prompt, settings, trigger button.
        with gr.Column(scale=1):
            text_input = gr.Textbox(
                label="Input Text",
                placeholder="Enter multi-speaker dialogue...",
                value=default_text,
                lines=8,
            )
            audio_prompt_input = gr.Audio(
                label="Audio Prompt (Optional)",
                show_label=True,
                sources=["upload", "microphone"],
                type="numpy",
            )

            with gr.Accordion("Advanced Settings", open=False):
                max_new_tokens = gr.Slider(
                    label="Max New Tokens", minimum=860, maximum=3072,
                    value=model.config.data.audio_length, step=50,
                )
                cfg_scale = gr.Slider(
                    label="CFG Scale", minimum=1.0, maximum=5.0, value=3.0, step=0.1,
                )
                temperature = gr.Slider(
                    label="Temperature", minimum=1.0, maximum=1.5, value=1.3, step=0.05,
                )
                top_p = gr.Slider(
                    label="Top P", minimum=0.8, maximum=1.0, value=0.95, step=0.01,
                )
                cfg_filter_top_k = gr.Slider(
                    label="CFG Filter Top K", minimum=15, maximum=50, value=30, step=1,
                )
                speed_factor_slider = gr.Slider(
                    label="Speed Factor", minimum=0.5, maximum=1.5, value=0.94, step=0.02,
                )

            # NOTE(review): assumed to sit outside the accordion, directly
            # above the button — verify.
            apply_summary = gr.Checkbox(
                label="Summarize Input Text before Generation?",
                value=False,
            )
            run_button = gr.Button("Generate Audio", variant="primary")

        # Right column: playback of the generated audio.
        with gr.Column(scale=1):
            audio_output = gr.Audio(
                label="Generated Audio",
                type="numpy",
                autoplay=False,
            )

    # Order must match run_inference's positional parameters.
    generation_inputs = [
        text_input,
        audio_prompt_input,
        max_new_tokens,
        cfg_scale,
        temperature,
        top_p,
        cfg_filter_top_k,
        speed_factor_slider,
        apply_summary,
    ]
    run_button.click(
        fn=run_inference,
        inputs=generation_inputs,
        outputs=[audio_output],
        api_name="generate_audio",
    )
# --- Launch ---
def main():
    """Print a startup message and launch the Gradio app."""
    print("Launching Gradio app...")
    demo.launch()


if __name__ == "__main__":
    main()