Spaces:
Running
Running
| # server.py | |
| # Main FastAPI server for Dia TTS | |
| import sys | |
| import logging | |
| import time | |
| import os | |
| import io | |
| import uuid | |
| import sys | |
| import shutil # For file copying | |
| import yaml # For loading presets | |
| from datetime import datetime | |
| from contextlib import asynccontextmanager | |
| from typing import Optional, Literal, List, Dict, Any | |
| import webbrowser | |
| import threading | |
| import time | |
| from fastapi import ( | |
| FastAPI, | |
| HTTPException, | |
| Request, | |
| Response, | |
| Form, | |
| UploadFile, | |
| File, | |
| BackgroundTasks, | |
| ) | |
| from fastapi.responses import ( | |
| StreamingResponse, | |
| JSONResponse, | |
| HTMLResponse, | |
| RedirectResponse, | |
| ) | |
| from fastapi.staticfiles import StaticFiles | |
| from fastapi.templating import Jinja2Templates | |
| import uvicorn | |
| import numpy as np | |
| # Internal imports | |
| from config import ( | |
| config_manager, | |
| get_host, | |
| get_port, | |
| get_output_path, | |
| get_reference_audio_path, | |
| # register_config_routes is now defined locally | |
| get_model_cache_path, | |
| get_model_repo_id, | |
| get_model_config_filename, | |
| get_model_weights_filename, | |
| # Generation default getters | |
| get_gen_default_speed_factor, | |
| get_gen_default_cfg_scale, | |
| get_gen_default_temperature, | |
| get_gen_default_top_p, | |
| get_gen_default_cfg_filter_top_k, | |
| DEFAULT_CONFIG, | |
| ) | |
| from models import OpenAITTSRequest, CustomTTSRequest, ErrorResponse | |
| import engine | |
| from engine import ( | |
| load_model as load_dia_model, | |
| generate_speech, | |
| EXPECTED_SAMPLE_RATE, | |
| ) | |
| from utils import encode_audio, save_audio_to_file, PerformanceMonitor | |
# Logging setup: INFO level; each record shows time, level, logger name, message.
logging.basicConfig(
    format="%(asctime)s [%(levelname)s] %(name)s: %(message)s",
    level=logging.INFO,
)
# Noisy third-party loggers can be quieted by raising their level, e.g.:
# logging.getLogger("uvicorn.access").setLevel(logging.WARNING)
# logging.getLogger("watchfiles").setLevel(logging.WARNING)

# Logger used throughout this module.
logger = logging.getLogger(__name__)

# --- Global Variables & Constants ---
# YAML file that load_presets() reads.
PRESETS_FILE = "ui/presets.yaml"
# In-memory cache of the presets; replaced wholesale by load_presets().
loaded_presets: List[Dict[str, Any]] = []
# Set by lifespan() once the startup sequence has finished.
startup_complete_event = threading.Event()
| # --- Helper Functions --- | |
def load_presets():
    """Refresh the module-level ``loaded_presets`` cache from ``PRESETS_FILE``.

    The cache ends up as the parsed YAML content when the file exists and
    holds a list; in every other case (file missing, wrong top-level type,
    YAML syntax error, any other exception) it ends up as an empty list and
    the reason is logged. Nothing is returned and nothing is raised.
    """
    global loaded_presets
    presets = []
    try:
        if not os.path.exists(PRESETS_FILE):
            logger.warning(
                f"Presets file not found at '{PRESETS_FILE}'. No presets will be available."
            )
        else:
            with open(PRESETS_FILE, "r", encoding="utf-8") as fh:
                parsed = yaml.safe_load(fh)
            if isinstance(parsed, list):
                presets = parsed
                logger.info(
                    f"Successfully loaded {len(presets)} presets from {PRESETS_FILE}."
                )
            else:
                # Anything but a list at the top level is treated as "no presets".
                logger.error(
                    f"Presets file '{PRESETS_FILE}' should contain a list, but found {type(parsed)}. No presets loaded."
                )
    except yaml.YAMLError as e:
        presets = []
        logger.error(
            f"Error parsing presets YAML file '{PRESETS_FILE}': {e}", exc_info=True
        )
    except Exception as e:
        presets = []
        logger.error(f"Error loading presets file '{PRESETS_FILE}': {e}", exc_info=True)
    # Publish the result in a single assignment.
    loaded_presets = presets
def get_valid_reference_files() -> list[str]:
    """Return the sorted names of ``.wav`` / ``.mp3`` entries in the reference directory.

    The directory is the one reported by ``get_reference_audio_path()``.
    Matching is by (case-insensitive) file extension only. A missing
    directory is logged as a warning and any exception while listing it is
    logged as an error; both cases yield an empty list.
    """
    ref_path = get_reference_audio_path()
    audio_suffixes = (".wav", ".mp3")
    matches = []
    try:
        if not os.path.isdir(ref_path):
            logger.warning(f"Reference audio directory not found: {ref_path}")
        else:
            # str.endswith accepts a tuple, so one call covers both extensions.
            matches = [
                entry
                for entry in os.listdir(ref_path)
                if entry.lower().endswith(audio_suffixes)
            ]
    except Exception as e:
        logger.error(
            f"Error reading reference audio directory '{ref_path}': {e}", exc_info=True
        )
    return sorted(matches)
def sanitize_filename(filename: str) -> str:
    """Return a filesystem-safe version of *filename*.

    Directory components are dropped, every character outside
    ``[A-Za-z0-9._-]`` is replaced with an underscore, leading dots are
    stripped so the result never starts with a dot, and the result is capped
    at 100 characters (keeping the extension whenever it fits).

    Args:
        filename: Name as supplied by the caller; may contain path parts.

    Returns:
        The cleaned name, or a generated ``uploaded_file_<8 hex chars>``
        fallback when nothing usable is left after cleaning.
    """
    max_len = 100
    safe_chars = frozenset(
        "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789._-"
    )

    # Remove directory separators / parent components.
    base = os.path.basename(filename)
    # Keep only alphanumeric, underscore, hyphen, dot; replace others with underscore.
    cleaned = "".join(ch if ch in safe_chars else "_" for ch in base)
    # Never return a name that starts with a dot (hidden file, "." or "..").
    cleaned = cleaned.lstrip(".")

    # Nothing meaningful left (empty, or only dots/underscores/spaces).
    if not cleaned.lstrip("._ "):
        return f"uploaded_file_{uuid.uuid4().hex[:8]}"  # Generate a safe fallback name

    if len(cleaned) > max_len:
        stem, ext = os.path.splitext(cleaned)
        if len(ext) < max_len:
            # Normal case: shorten the stem and keep the whole extension.
            cleaned = stem[: max_len - len(ext)] + ext
        else:
            # The "extension" alone would exceed the cap; a stem slice with a
            # negative bound would not enforce the limit, so hard-truncate.
            cleaned = cleaned[:max_len]
    return cleaned
| # --- Application Lifespan (Startup/Shutdown) --- | |
async def lifespan(app: FastAPI):
    """Application lifespan manager for startup/shutdown.

    Startup creates the working directories, loads the presets and the Dia
    model, starts a daemon thread that opens the browser, and finally sets
    ``startup_complete_event`` before yielding control to the application.
    A failed model load raises ``RuntimeError``. Any exception is logged and
    re-raised, and the event is left unset in that case. The shutdown
    messages are logged from ``finally``, so they also appear after a failed
    startup.
    """
    # NOTE(review): `asynccontextmanager` is imported at the top of the file but
    # no decorator is visible on this function in this view — confirm.
    try:
        logger.info("Starting Dia TTS server initialization...")

        # Configured working directories first, then the fixed UI/static ones.
        for resolve_dir in (
            get_output_path,
            get_reference_audio_path,
            get_model_cache_path,
        ):
            os.makedirs(resolve_dir(), exist_ok=True)
        for fixed_dir in ("ui", "static"):
            os.makedirs(fixed_dir, exist_ok=True)

        # Load presets from YAML file
        load_presets()

        # The server is useless without the model: abort startup if it fails.
        if not load_dia_model():
            error_msg = (
                "CRITICAL: Failed to load Dia model on startup. Server cannot start."
            )
            logger.critical(error_msg)
            raise RuntimeError(error_msg)
        logger.info("Dia model loaded successfully.")

        # Start the delayed browser opener only AFTER model loading completed.
        host = get_host()
        port = get_port()
        browser_thread = threading.Thread(
            target=_delayed_browser_open, args=(host, port), daemon=True
        )
        browser_thread.start()

        # Signal readiness after the potentially long operations above.
        logger.info("Application startup sequence finished. Signaling readiness.")
        startup_complete_event.set()
        yield  # Application runs here
    except Exception as e:
        # Covers the RuntimeError raised above as well as any other error;
        # the readiness event is deliberately NOT set on this path.
        logger.error(f"Fatal error during application startup: {e}", exc_info=True)
        raise
    finally:
        # Cleanup on shutdown (no resource-specific cleanup is needed yet).
        logger.info("Application shutdown initiated...")
        logger.info("Application shutdown complete.")
def _delayed_browser_open(host, port):
    """Opens browser after a short delay to ensure server is ready.

    Meant to be used as a thread target. The URL being opened and any
    failure are appended to ``browser_thread_debug.log`` in the current
    working directory.

    Args:
        host: Host the server is bound to; ``0.0.0.0`` is shown to the
            browser as ``localhost``.
        port: Port used to build the URL.
    """
    try:
        # Small delay before opening the page.
        time.sleep(2)
        display_host = "localhost" if host == "0.0.0.0" else host
        browser_url = f"http://{display_host}:{port}/"

        # Log to file for debugging
        with open("browser_thread_debug.log", "a") as f:
            f.write(f"[{time.time()}] Opening browser at {browser_url}\n")

        # Logging here is best-effort and must not prevent the browser from
        # opening. Catch Exception rather than using a bare ``except`` so that
        # SystemExit / KeyboardInterrupt are not swallowed.
        try:
            logger.info(f"Opening browser at {browser_url}")
        except Exception:
            pass

        # Open browser directly without health checks
        webbrowser.open(browser_url)
    except Exception as e:
        with open("browser_thread_debug.log", "a") as f:
            f.write(f"[{time.time()}] Browser open error: {str(e)}\n")
| # --- FastAPI App Initialization --- | |
app = FastAPI(
    title="Dia TTS Server",
    description="Text-to-Speech server using the Dia model, providing API and Web UI.",
    version="1.1.0",
    lifespan=lifespan,
)

# Working folders that must exist relative to the current directory.
folders = ["reference_audio", "model_cache", "outputs"]
for folder in folders:
    if os.path.exists(folder):
        continue  # Already there: nothing to create, nothing to report.
    os.makedirs(folder)
    print(f"Created directory: {folder}")

# --- Static Files and Templates ---
# NOTE(review): the mounts below use get_output_path() and "ui", while the loop
# above only creates the three hard-coded folder names — confirm these line up
# for non-default configurations.
# Generated audio is served from the configured output path.
app.mount("/outputs", StaticFiles(directory=get_output_path()), name="outputs")
# UI assets (CSS, JS) are served from the 'ui' directory.
app.mount("/ui", StaticFiles(directory="ui"), name="ui_static")
# Jinja2 templates are looked up in the same 'ui' directory.
templates = Jinja2Templates(directory="ui")
| # --- Configuration Routes Definition --- | |
| # Defined locally now instead of importing from config.py | |
def register_config_routes(app: FastAPI):
    """Adds configuration management endpoints to the FastAPI app.

    Defines four handlers: one returning the current configuration, two that
    update ``config_manager`` and persist it, and one that schedules a
    process exit so an external runner can restart the server.

    Args:
        app: The FastAPI application the routes belong to.
    """
    logger.info(
        "Registering configuration routes (/get_config, /save_config, /restart_server, /save_generation_defaults)."
    )

    # NOTE(review): no route decorators are visible on the handlers below in
    # this view of the file; confirm how they are attached to `app`.

    async def get_current_config():
        """Returns the current server configuration values (from .env or defaults)."""
        logger.info("Request received for /get_config")
        return JSONResponse(content=config_manager.get_all())

    async def save_new_config(request: Request):
        """
        Saves updated server configuration values (Host, Port, Model paths, etc.)
        to the .env file. Requires server restart to apply most changes.

        Raises:
            HTTPException: 400 when the body is not a JSON object, 500 when
                saving fails or an unexpected error occurs.
        """
        logger.info("Request received for /save_config")
        try:
            new_config_data = await request.json()
            if not isinstance(new_config_data, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received server config data to save: {new_config_data}")

            # Filter data to only include keys present in DEFAULT_CONFIG
            filtered_data = {
                k: v for k, v in new_config_data.items() if k in DEFAULT_CONFIG
            }
            unknown_keys = set(new_config_data.keys()) - set(filtered_data.keys())
            if unknown_keys:
                logger.warning(
                    f"Ignoring unknown keys in save_config request: {unknown_keys}"
                )

            config_manager.update(filtered_data)  # Update in memory first
            if config_manager.save():  # Attempt to save to .env
                logger.info("Server configuration saved successfully to .env.")
                return JSONResponse(
                    content={
                        "message": "Server configuration saved. Restart server to apply changes."
                    }
                )
            logger.error("Failed to save server configuration to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # Pass our own HTTP errors through untouched; otherwise the generic
            # handler below would re-wrap them and garble the detail message.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_config: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(f"Error processing /save_config request: {e}", exc_info=True)
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    async def save_generation_defaults(request: Request):
        """
        Saves the provided generation parameters (speed, cfg, temp, etc.)
        as the new defaults in the .env file. These are loaded by the UI on startup.

        Raises:
            HTTPException: 400 when the body is not a JSON object or holds no
                known parameter, 500 when saving fails or an unexpected error
                occurs.
        """
        logger.info("Request received for /save_generation_defaults")
        try:
            gen_params = await request.json()
            if not isinstance(gen_params, dict):
                raise ValueError("Request body must be a JSON object.")
            logger.debug(f"Received generation defaults to save: {gen_params}")

            # Map received keys (e.g., 'speed_factor') to .env keys (e.g., 'GEN_DEFAULT_SPEED_FACTOR')
            key_map = {
                "speed_factor": "GEN_DEFAULT_SPEED_FACTOR",
                "cfg_scale": "GEN_DEFAULT_CFG_SCALE",
                "temperature": "GEN_DEFAULT_TEMPERATURE",
                "top_p": "GEN_DEFAULT_TOP_P",
                "cfg_filter_top_k": "GEN_DEFAULT_CFG_FILTER_TOP_K",
            }
            defaults_to_save = {}
            for ui_key, env_key in key_map.items():
                if ui_key in gen_params:
                    # Values are stored as strings; no type validation is done here.
                    defaults_to_save[env_key] = str(gen_params[ui_key])
                else:
                    logger.warning(
                        f"Missing expected key '{ui_key}' in save_generation_defaults request."
                    )
            if not defaults_to_save:
                raise ValueError("No valid generation parameters found in the request.")

            config_manager.update(defaults_to_save)  # Update in memory
            # Save all current config (including these) to .env
            if config_manager.save():
                logger.info("Generation defaults saved successfully to .env.")
                return JSONResponse(content={"message": "Generation defaults saved."})
            logger.error("Failed to save generation defaults to .env file.")
            raise HTTPException(
                status_code=500, detail="Failed to save configuration file."
            )
        except HTTPException:
            # Same pass-through as in save_new_config.
            raise
        except ValueError as ve:
            logger.error(f"Invalid data format for /save_generation_defaults: {ve}")
            raise HTTPException(
                status_code=400, detail=f"Invalid request data: {str(ve)}"
            )
        except Exception as e:
            logger.error(
                f"Error processing /save_generation_defaults request: {e}",
                exc_info=True,
            )
            raise HTTPException(
                status_code=500, detail=f"Internal server error during save: {str(e)}"
            )

    async def trigger_server_restart(background_tasks: BackgroundTasks):
        """
        Attempts to restart the server process.
        NOTE: This is highly dependent on how the server is run (e.g., with uvicorn --reload,
        or managed by systemd/supervisor). A simple exit might just stop the process.
        This implementation attempts a clean exit, relying on the runner to restart it.
        """
        logger.warning("Received request to restart server via API.")

        def _do_restart():
            time.sleep(1)  # Short delay to allow response to be sent
            logger.warning("Attempting clean exit for restart...")
            # Clean exit; relies on Uvicorn reload or a process manager to
            # bring the server back. (An os.execv based re-exec was considered
            # as an alternative but is not used.)
            sys.exit(0)

        background_tasks.add_task(_do_restart)
        return JSONResponse(
            content={
                "message": "Restart signal sent. Server should restart shortly if run with auto-reload."
            }
        )


# --- Register Configuration Routes ---
register_config_routes(app)
| # --- API Endpoints --- | |
async def openai_tts_endpoint(request: OpenAITTSRequest):
    """
    Generates speech audio from text, compatible with the OpenAI TTS API structure.
    Maps the 'voice' parameter to Dia's voice modes ('S1', 'S2', 'dialogue', or filename for clone).

    A ``voice`` ending in ``.wav``/``.mp3`` selects clone mode only when it is
    a bare file name that exists in the reference audio directory; anything
    else falls back to ``single_s1`` with a warning.

    Args:
        request: Parsed request body (input text, voice, speed, response format).

    Returns:
        A ``StreamingResponse`` with the encoded audio.

    Raises:
        HTTPException: 500 when generation or encoding fails or an unexpected
            error occurs.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received OpenAI request: voice='{request.voice}', speed={request.speed}, format='{request.response_format}'"
    )
    logger.debug(f"Input text (start): '{request.input[:100]}...'")

    voice_mode = "single_s1"  # Default if mapping fails
    clone_ref_file = None
    ref_path = get_reference_audio_path()

    # --- Map OpenAI 'voice' parameter to Dia's modes ---
    voice_param = request.voice.strip()
    voice_key = voice_param.lower()
    if voice_key == "dialogue":
        voice_mode = "dialogue"
    elif voice_key == "s1":
        voice_mode = "single_s1"
    elif voice_key == "s2":
        voice_mode = "single_s2"
    # Check if it looks like a filename for cloning (allow .wav or .mp3)
    elif voice_key.endswith((".wav", ".mp3")):
        if os.path.basename(voice_param) != voice_param:
            # The value comes straight from the client: refuse anything with
            # directory components so it cannot escape the reference directory.
            logger.warning(
                f"Rejected voice parameter '{voice_param}' because it contains path components. Defaulting voice mode to 'single_s1'."
            )
        elif os.path.isfile(os.path.join(ref_path, voice_param)):
            voice_mode = "clone"
            clone_ref_file = voice_param  # Use the provided filename
            logger.info(
                f"OpenAI request mapped to clone mode with file: {clone_ref_file}"
            )
        else:
            # Fallback to default 'single_s1' if file not found
            logger.warning(
                f"Reference file '{voice_param}' specified in OpenAI request not found in '{ref_path}'. Defaulting voice mode."
            )
    else:
        # Fallback for any other value
        logger.warning(
            f"Unrecognized OpenAI voice parameter '{voice_param}'. Defaulting voice mode to 'single_s1'."
        )
    monitor.record("Parameters processed")

    try:
        # Call the core engine function using mapped parameters; generation
        # settings not present in the OpenAI request use the saved defaults.
        result = generate_speech(
            text=request.input,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=request.speed,
            max_tokens=None,
            cfg_scale=get_gen_default_cfg_scale(),
            temperature=get_gen_default_temperature(),
            top_p=get_gen_default_top_p(),
            cfg_filter_top_k=get_gen_default_cfg_filter_top_k(),
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, but expected {EXPECTED_SAMPLE_RATE}. Encoding might assume {EXPECTED_SAMPLE_RATE}."
            )
            # Encode with the expected rate regardless of what was reported.
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory to the requested format
        encoded_audio = encode_audio(audio_array, sample_rate, request.response_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.response_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.response_format}",
            )

        # Every format other than "opus" is reported as audio/wav.
        media_type = "audio/opus" if request.response_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.response_format}"
        )
        logger.debug(monitor.report())
        # Stream the encoded audio back to the client
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Our own HTTP errors are logged and passed through unchanged.
        logger.error(f"HTTP exception during OpenAI request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing OpenAI TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        # Return generic server error for unexpected issues
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
async def custom_tts_endpoint(request: CustomTTSRequest):
    """
    Generates speech audio from text using explicit Dia parameters.

    Args:
        request: Parsed request body with the text, voice mode, optional
            clone reference file name, generation parameters and output format.

    Returns:
        A ``StreamingResponse`` with the encoded audio.

    Raises:
        HTTPException: 400 when clone mode is requested without a reference
            file name or with a name containing path components, 404 when the
            reference file does not exist, 500 when generation or encoding
            fails or an unexpected error occurs.
    """
    monitor = PerformanceMonitor()
    monitor.record("Request received")
    logger.info(
        f"Received custom TTS request: mode='{request.voice_mode}', format='{request.output_format}'"
    )
    logger.debug(f"Input text (start): '{request.text[:100]}...'")
    logger.debug(
        f"Params: max_tokens={request.max_tokens}, cfg={request.cfg_scale}, temp={request.temperature}, top_p={request.top_p}, speed={request.speed_factor}, top_k={request.cfg_filter_top_k}"
    )

    clone_ref_file = None
    if request.voice_mode == "clone":
        requested_name = request.clone_reference_filename
        if not requested_name:
            raise HTTPException(
                status_code=400,  # Bad request
                detail="Missing 'clone_reference_filename' which is required for clone mode.",
            )
        if os.path.basename(requested_name) != requested_name:
            # Client-controlled value: only a bare file name may be joined to
            # the reference directory, otherwise it could point outside of it.
            logger.error(
                f"Rejected clone reference filename with path components: {requested_name}"
            )
            raise HTTPException(
                status_code=400,
                detail="Invalid 'clone_reference_filename': must be a plain file name without path components.",
            )
        potential_path = os.path.join(get_reference_audio_path(), requested_name)
        if not os.path.isfile(potential_path):
            logger.error(
                f"Reference audio file not found for clone mode: {potential_path}"
            )
            raise HTTPException(
                status_code=404,  # Not found
                detail=f"Reference audio file not found: {request.clone_reference_filename}",
            )
        clone_ref_file = requested_name
        logger.info(f"Custom request using clone mode with file: {clone_ref_file}")
    monitor.record("Parameters processed")

    try:
        # Call the core engine function with parameters from the request
        result = generate_speech(
            text=request.text,
            voice_mode=request.voice_mode,
            clone_reference_filename=clone_ref_file,
            max_tokens=request.max_tokens,  # Pass user value or None
            cfg_scale=request.cfg_scale,
            temperature=request.temperature,
            top_p=request.top_p,
            speed_factor=request.speed_factor,
            cfg_filter_top_k=request.cfg_filter_top_k,
        )
        monitor.record("Generation complete")
        if result is None:
            logger.error("Speech generation failed (engine returned None).")
            raise HTTPException(status_code=500, detail="Speech generation failed.")

        audio_array, sample_rate = result
        if sample_rate != EXPECTED_SAMPLE_RATE:
            logger.warning(
                f"Engine returned sample rate {sample_rate}, expected {EXPECTED_SAMPLE_RATE}. Encoding will use {EXPECTED_SAMPLE_RATE}."
            )
            sample_rate = EXPECTED_SAMPLE_RATE

        # Encode the audio in memory
        encoded_audio = encode_audio(audio_array, sample_rate, request.output_format)
        monitor.record("Audio encoding complete")
        if encoded_audio is None:
            logger.error(f"Failed to encode audio to format: {request.output_format}")
            raise HTTPException(
                status_code=500,
                detail=f"Failed to encode audio to {request.output_format}",
            )

        # Every format other than "opus" is reported as audio/wav.
        media_type = "audio/opus" if request.output_format == "opus" else "audio/wav"
        logger.info(
            f"Successfully generated {len(encoded_audio)} bytes in format {request.output_format}"
        )
        logger.debug(monitor.report())
        # Stream the response
        return StreamingResponse(io.BytesIO(encoded_audio), media_type=media_type)
    except HTTPException as http_exc:
        # Our own HTTP errors are logged and passed through unchanged.
        logger.error(f"HTTP exception during custom TTS request: {http_exc.detail}")
        raise
    except Exception as e:
        logger.error(f"Error processing custom TTS request: {e}", exc_info=True)
        logger.debug(monitor.report())
        raise HTTPException(status_code=500, detail=f"Internal server error: {str(e)}")
| # --- Web UI Endpoints --- | |
async def get_web_ui(request: Request):
    """Serves the main TTS web interface.

    Renders ``index.html`` in its initial state: current reference files,
    server configuration, cached presets and the saved generation defaults,
    with every result/submission field empty.
    """
    logger.info("Serving TTS Web UI (index.html)")
    context = {
        "request": request,
        # Current list of reference files for the clone dropdown
        "reference_files": get_valid_reference_files(),
        # Current server config
        "config": config_manager.get_all(),
        "presets": loaded_presets,
        # Saved default generation parameters
        "default_gen_params": {
            "speed_factor": get_gen_default_speed_factor(),
            "cfg_scale": get_gen_default_cfg_scale(),
            "temperature": get_gen_default_temperature(),
            "top_p": get_gen_default_top_p(),
            "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
        },
        # Initial (empty) page state
        "error": None,
        "success": None,
        "output_file_url": None,
        "generation_time": None,
        "submitted_text": "",
        "submitted_voice_mode": "dialogue",  # Default to combined mode
        "submitted_clone_file": None,
    }
    return templates.TemplateResponse("index.html", context)
async def handle_web_ui_generate(
    request: Request,
    text: str = Form(...),
    voice_mode: Literal["dialogue", "clone"] = Form(...),
    clone_reference_select: Optional[str] = Form(None),
    # Generation parameters from form
    speed_factor: float = Form(...),
    cfg_scale: float = Form(...),
    temperature: float = Form(...),
    top_p: float = Form(...),
    cfg_filter_top_k: int = Form(...),
):
    """Handles the generation request from the web UI form.

    Validates the submitted text and (in clone mode) the selected reference
    file, runs the engine, saves the result as a WAV file in the output
    directory and re-renders ``index.html`` with either the link to the
    generated file or an error message. The submitted values are always
    passed back so the form can be repopulated.

    Args:
        request: Incoming request, forwarded to the template.
        text: Text to synthesize.
        voice_mode: ``"dialogue"`` or ``"clone"``.
        clone_reference_select: Reference file chosen in the UI (clone mode).
        speed_factor: Generation parameter forwarded to the engine.
        cfg_scale: Generation parameter forwarded to the engine.
        temperature: Generation parameter forwarded to the engine.
        top_p: Generation parameter forwarded to the engine.
        cfg_filter_top_k: Generation parameter forwarded to the engine.

    Returns:
        The rendered ``index.html`` template response.
    """
    logger.info(f"Web UI generation request: mode='{voice_mode}'")
    monitor = PerformanceMonitor()
    monitor.record("Web request received")

    def _render(*, error, clone_file, success=None, output_url=None, elapsed=None):
        """Render index.html with fresh page data plus the given outcome."""
        return templates.TemplateResponse(
            "index.html",
            {
                "request": request,
                "error": error,
                "success": success,
                "output_file_url": output_url,
                "generation_time": f"{elapsed:.2f}" if elapsed else None,
                "reference_files": get_valid_reference_files(),
                "config": config_manager.get_all(),
                "presets": loaded_presets,
                # Base defaults
                "default_gen_params": {
                    "speed_factor": get_gen_default_speed_factor(),
                    "cfg_scale": get_gen_default_cfg_scale(),
                    "temperature": get_gen_default_temperature(),
                    "top_p": get_gen_default_top_p(),
                    "cfg_filter_top_k": get_gen_default_cfg_filter_top_k(),
                },
                # Submitted values to repopulate the form
                "submitted_text": text,
                "submitted_voice_mode": voice_mode,
                "submitted_clone_file": clone_file,
                "submitted_gen_params": {
                    "speed_factor": speed_factor,
                    "cfg_scale": cfg_scale,
                    "temperature": temperature,
                    "top_p": top_p,
                    "cfg_filter_top_k": cfg_filter_top_k,
                },
            },
        )

    # --- Pre-generation Validation ---
    error_message = None
    if not text.strip():
        error_message = "Please enter some text to synthesize."

    clone_ref_file = None
    if voice_mode == "clone":
        if not clone_reference_select or clone_reference_select == "none":
            error_message = "Please select a reference audio file for clone mode."
        elif os.path.basename(clone_reference_select) != clone_reference_select:
            # Form data is client-controlled: only a bare file name may be
            # joined to the reference directory.
            error_message = "Invalid reference file name."
            clone_reference_select = None  # Clear submitted value for re-rendering
        elif not os.path.isfile(
            os.path.join(get_reference_audio_path(), clone_reference_select)
        ):
            # The file may have been deleted since the page was rendered.
            error_message = f"Selected reference file '{clone_reference_select}' no longer exists. Please refresh or upload."
            clone_reference_select = None  # Clear submitted value for re-rendering
        else:
            clone_ref_file = clone_reference_select
            logger.info(f"Using selected reference file: {clone_ref_file}")

    # If validation failed, re-render the page with error and submitted values
    if error_message:
        logger.warning(f"Web UI validation error: {error_message}")
        return _render(error=error_message, clone_file=clone_reference_select)

    # --- Generation ---
    output_file_url = None
    generation_time = None
    success_message = None
    try:
        monitor.record("Parameters processed")
        result = generate_speech(
            text=text,
            voice_mode=voice_mode,
            clone_reference_filename=clone_ref_file,
            speed_factor=speed_factor,
            cfg_scale=cfg_scale,
            temperature=temperature,
            top_p=top_p,
            cfg_filter_top_k=cfg_filter_top_k,
            max_tokens=None,  # Use model default for UI simplicity
        )
        monitor.record("Generation complete")

        if result:
            audio_array, sample_rate = result
            timestamp = datetime.now().strftime("%Y%m%d_%H%M%S")
            # Descriptive file name: mode (plus shortened reference name) and time.
            mode_tag = voice_mode
            if voice_mode == "clone" and clone_ref_file:
                safe_ref_name = sanitize_filename(os.path.splitext(clone_ref_file)[0])
                mode_tag = f"clone_{safe_ref_name[:20]}"  # Limit length
            # Always save as WAV for simplicity
            output_filename = f"{mode_tag}_{timestamp}.wav"
            output_filepath = os.path.join(get_output_path(), output_filename)

            saved = save_audio_to_file(audio_array, sample_rate, output_filepath)
            monitor.record("Audio saved")
            if saved:
                # URL path for browser access
                output_file_url = f"/outputs/{output_filename}"
                # Elapsed time from request start until the save completed.
                generation_time = monitor.events[-1][1] - monitor.start_time
                success_message = "Audio generated successfully!"
                logger.info(f"Web UI generated audio saved to: {output_filepath}")
            else:
                error_message = "Failed to save generated audio file."
                logger.error("Failed to save audio file from web UI request.")
        else:
            error_message = "Speech generation failed (engine returned None)."
            logger.error("Speech generation failed for web UI request.")
    except Exception as e:
        logger.error(f"Error processing web UI TTS request: {e}", exc_info=True)
        error_message = f"An unexpected error occurred: {str(e)}"
    logger.debug(monitor.report())

    # --- Re-render Template with Results ---
    return _render(
        error=error_message,
        clone_file=clone_ref_file,  # Pass the validated filename back
        success=success_message,
        output_url=output_file_url,
        elapsed=generation_time,
    )
| # --- Reference Audio Upload Endpoint --- | |
async def upload_reference_audio(files: List[UploadFile] = File(...)):
    """Handle uploading of reference audio files (.wav, .mp3) for voice cloning.

    Each file is validated (non-empty filename, allowed extension, allowed
    declared content type) and then written into the reference audio
    directory under its sanitized name. A file whose sanitized name already
    exists on disk is not overwritten, but is still reported in
    ``uploaded_files`` so the caller knows it is available.

    Args:
        files: The uploaded files from the multipart request.

    Returns:
        A JSONResponse with the keys ``message``, ``uploaded_files``,
        ``all_reference_files`` and ``errors``. The status code is 200 when
        there are no errors or fewer errors than files, otherwise 400.
    """
    logger.info(f"Received request to upload {len(files)} reference audio file(s).")
    ref_path = get_reference_audio_path()
    uploaded_filenames = []
    errors = []
    # Kept as lists: they are interpolated verbatim into the error messages.
    allowed_mime_types = [
        "audio/wav",
        "audio/mpeg",
        "audio/x-wav",
    ]  # Common WAV/MP3 types
    allowed_extensions = [".wav", ".mp3"]

    for file in files:
        try:
            # Basic validation
            if not file.filename:
                errors.append("Received file with no filename.")
                continue

            safe_filename = sanitize_filename(file.filename)
            _, ext = os.path.splitext(safe_filename)
            if ext.lower() not in allowed_extensions:
                errors.append(
                    f"File '{file.filename}' has unsupported extension '{ext}'. Allowed: {allowed_extensions}"
                )
                continue

            # The content type is whatever the client declared for this part,
            # so this is a sanity check rather than proof of the file's format.
            if file.content_type not in allowed_mime_types:
                errors.append(
                    f"File '{file.filename}' has unsupported content type '{file.content_type}'. Allowed: {allowed_mime_types}"
                )
                continue

            destination_path = os.path.join(ref_path, safe_filename)

            # Never overwrite an existing reference file; just report it.
            if os.path.exists(destination_path):
                logger.warning(
                    f"Reference file '{safe_filename}' already exists. Skipping upload."
                )
                # Add to list so UI knows it's available, even if not newly uploaded this time
                if safe_filename not in uploaded_filenames:
                    uploaded_filenames.append(safe_filename)
                continue

            # Stream to disk with shutil.copyfileobj so large files are not
            # read into memory in one piece.
            try:
                with open(destination_path, "wb") as buffer:
                    shutil.copyfileobj(file.file, buffer)
            except Exception as save_exc:
                errors.append(f"Failed to save file '{safe_filename}': {save_exc}")
                logger.error(
                    f"Failed to save uploaded file '{safe_filename}' to '{destination_path}': {save_exc}",
                    exc_info=True,
                )
                # Remove the truncated file: left in place it would be listed
                # as a valid reference and would make every later upload of
                # the same name hit the "already exists" branch above.
                try:
                    if os.path.exists(destination_path):
                        os.remove(destination_path)
                except OSError as cleanup_exc:
                    logger.warning(
                        f"Could not remove partial file '{destination_path}': {cleanup_exc}"
                    )
            else:
                logger.info(f"Successfully saved reference file: {destination_path}")
                uploaded_filenames.append(safe_filename)
        except Exception as e:
            errors.append(
                f"Error processing file '{getattr(file, 'filename', 'unknown')}': {e}"
            )
            logger.error(
                f"Unexpected error processing uploaded file: {e}", exc_info=True
            )
        finally:
            # Close the upload on every path (including the `continue`
            # branches above). A failing close must not abort the batch.
            try:
                await file.close()
            except Exception as close_exc:
                logger.warning(
                    f"Failed to close uploaded file '{getattr(file, 'filename', 'unknown')}': {close_exc}"
                )

    # Get the updated list of all valid files in the directory
    updated_file_list = get_valid_reference_files()
    response_data = {
        "message": f"Processed {len(files)} file(s).",
        "uploaded_files": uploaded_filenames,  # Saved this request, or already present
        "all_reference_files": updated_file_list,  # Complete current list
        "errors": errors,
    }
    # OK if at least one file went through without an error, else Bad Request
    status_code = 200 if not errors or len(errors) < len(files) else 400
    if errors:
        logger.warning(f"Upload completed with errors: {errors}")
    return JSONResponse(content=response_data, status_code=status_code)
| # --- Health Check Endpoint --- | |
async def health_check():
    """Report that the server is up and whether the model is loaded.

    Returns:
        A dict with ``status`` fixed to ``"healthy"`` and ``model_loaded``
        taken from the engine module's ``MODEL_LOADED`` attribute (``False``
        when that attribute is absent).
    """
    # Look the flag up on the engine module at call time so the response
    # reflects the current value rather than one captured at import.
    model_loaded = getattr(engine, "MODEL_LOADED", False)
    logger.debug(f"Health check returning model_loaded status: {model_loaded}")
    payload = {"status": "healthy", "model_loaded": model_loaded}
    return payload
| # --- Main Execution --- | |
if __name__ == "__main__":
    # Script entry point: log the effective configuration, make sure a UI
    # template exists, then hand control to Uvicorn.
    host = get_host()
    port = get_port()
    logger.info(f"Starting Dia TTS server on {host}:{port}")

    # Log each configured model/path setting, one line per setting.
    for setting_label, setting_getter in (
        ("Model Repository", get_model_repo_id),
        ("Model Config File", get_model_config_filename),
        ("Model Weights File", get_model_weights_filename),
        ("Model Cache Path", get_model_cache_path),
        ("Reference Audio Path", get_reference_audio_path),
        ("Output Path", get_output_path),
    ):
        logger.info(f"{setting_label}: {setting_getter()}")

    # The wildcard bind address is not a usable URL host, so show localhost.
    if host == "0.0.0.0":
        display_host = "localhost"
    else:
        display_host = host
    base_url = f"http://{display_host}:{port}"
    logger.info(f"Web UI will be available at {base_url}/")
    logger.info(f"API Docs available at {base_url}/docs")

    # The web UI needs ui/index.html; warn and drop in a placeholder page
    # when it is missing.
    ui_dir = "ui"
    index_file = os.path.join(ui_dir, "index.html")
    ui_ready = os.path.isdir(ui_dir) and os.path.isfile(index_file)
    if not ui_ready:
        logger.warning(
            f"'{ui_dir}' directory or '{index_file}' not found. Web UI may not work."
        )
        os.makedirs(ui_dir, exist_ok=True)
        if not os.path.isfile(index_file):
            try:
                with open(index_file, "w") as f:
                    f.write(
                        "<html><body>Web UI template missing. See project source for index.html.</body></html>"
                    )
                logger.info(f"Created dummy {index_file}.")
            except Exception as e:
                logger.error(f"Failed to create dummy {index_file}: {e}")

    # Synchronization event for startup completion.
    # NOTE(review): intended to be set by the lifespan manager once startup
    # (incl. model loading) finishes; that code is outside this section.
    startup_complete_event = threading.Event()

    # Start the server. The app is referenced by import string
    # ('module:app_instance'); lifespan="on" enables the lifespan protocol so
    # the app's startup/shutdown handling runs.
    # For development, reload=True can be combined with reload_dirs /
    # reload_includes; keep a single worker with reload or global model state.
    uvicorn.run(
        "server:app",
        host=host,
        port=port,
        reload=False,
        lifespan="on",
    )