"""Hugging Face Spaces (ZeroGPU) entry point for the dots.tts Gradio playground.

Reads configuration from ``DOTS_TTS_*`` environment variables, preloads the
default model runtime, optionally runs a startup compile inside a
``spaces.GPU``-decorated call, wraps generation in ``spaces.GPU`` and
launches the Gradio demo.
"""

from __future__ import annotations

import os
import sys
from pathlib import Path
from typing import Any, Callable

REPO_ROOT = Path(__file__).resolve().parent
SRC_ROOT = REPO_ROOT / "src"

# Put the repository root and ``src/`` on sys.path so the ``apps`` and
# ``dots_tts`` packages imported in ``main`` resolve when run as a script.
for import_root in (REPO_ROOT, SRC_ROOT):
    import_root_str = str(import_root)
    if import_root_str not in sys.path:
        sys.path.insert(0, import_root_str)


class _SpacesFallback:
    """Stand-in for the ``spaces`` package when it cannot be imported.

    Supports both decorator forms of ``spaces.GPU`` -- bare ``@spaces.GPU``
    and parameterised ``@spaces.GPU(duration=...)`` -- and returns the
    decorated function unchanged, so the app runs outside ZeroGPU.
    """

    @staticmethod
    def GPU(*decorator_args, **_decorator_kwargs):  # noqa: N802 - mirrors spaces.GPU
        # Bare form: the decorated function arrives as the first positional arg.
        if decorator_args and callable(decorator_args[0]):
            return decorator_args[0]

        # Parameterised form: hand back a no-op decorator.
        def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
            return fn

        return decorate


try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover - only used outside Hugging Face Spaces.
    spaces = _SpacesFallback()  # type: ignore


_TRUE_VALUES = frozenset({"1", "true", "yes", "on"})


def _env_str(name: str, default: str) -> str:
    """Return environment variable *name*, or *default* when it is unset or blank.

    The value is returned as-is (not stripped) when it is non-blank.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    return value


def _env_bool(name: str, default: bool) -> bool:
    """Parse a boolean environment variable.

    Unset or blank values yield *default*, matching ``_env_int``. Otherwise
    ``1``/``true``/``yes``/``on`` (case-insensitive, surrounding whitespace
    ignored) mean ``True``; any other value means ``False``.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    return value.strip().lower() in _TRUE_VALUES


def _env_int(name: str, default: int) -> int:
    """Parse an integer environment variable; unset or blank yields *default*.

    Raises:
        ValueError: if the variable is set to a value ``int()`` rejects. The
            message names the variable so a misconfigured Space is easy to
            diagnose.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}"
        ) from err


def _configure_zero_gpu_environment() -> None:
    """Seed ZeroGPU-oriented defaults without overriding explicit settings.

    Uses ``os.environ.setdefault``, so any variable already present in the
    environment (even a blank one) is left untouched.
    """
    os.environ.setdefault("DOTS_TTS_COMPILE_BACKEND", "aoti")
    # NOTE(review): presumably read by the runtime to skip warmup at load time;
    # the startup compile in ``main`` runs warmup explicitly instead.
    os.environ.setdefault("DOTS_TTS_SKIP_INIT_WARMUP", "1")
    # Read by PyTorch when its CUDA allocator initialises, so it must be set
    # before torch is imported (``main`` calls this before its imports).
    os.environ.setdefault("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True")


def _preload_runtime(app_service, app_config, compile_backend: str):
    """Obtain the default model runtime and apply the optimisation settings.

    Args:
        app_service: The ``GradioAppService`` whose private ``_get_runtime``
            supplies the runtime.
        app_config: App config; ``default_model_name_or_path`` and
            ``optimize`` are read from it.
        compile_backend: Backend name passed to
            ``runtime.model.set_compile_backend`` when the model has it.

    Returns:
        ``(runtime, resolved_model_name_or_path)`` from ``_get_runtime``.
    """
    runtime, resolved_model_name_or_path = app_service._get_runtime(  # noqa: SLF001
        app_config.default_model_name_or_path,
    )
    runtime.optimize = bool(app_config.optimize)
    runtime.model.set_optimize(bool(app_config.optimize))
    # Not every model implementation exposes a selectable compile backend.
    if hasattr(runtime.model, "set_compile_backend"):
        runtime.model.set_compile_backend(compile_backend)
    return runtime, resolved_model_name_or_path


def main() -> None:
    """Configure, preload, optionally compile at startup, and launch the Space app.

    All settings come from ``DOTS_TTS_*`` environment variables; unset or
    blank variables fall back to the defaults below.
    """
    _configure_zero_gpu_environment()

    # Imported only after the environment defaults above are in place.
    import gradio as gr
    from loguru import logger

    from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
    from apps.gradio.service import GradioAppService, build_gradio_app_config
    from dots_tts.utils.logging import configure_logging

    host = _env_str("DOTS_TTS_HOST", "0.0.0.0")
    port = _env_int("DOTS_TTS_PORT", 7860)
    model_name_or_path = _env_str(
        "DOTS_TTS_MODEL_NAME_OR_PATH",
        "rednote-hilab/dots.tts",
    )
    precision = _env_str("DOTS_TTS_PRECISION", "bfloat16")
    execution_mode = _env_str("DOTS_TTS_EXECUTION_MODE", "generate_stream")
    max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
    default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
    # A blank value must fall back to the default: _configure_zero_gpu_environment
    # only uses setdefault, so "" would otherwise reach set_compile_backend("").
    compile_backend = _env_str("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
    enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
    startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
    optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
    generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 60)
    compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
    output_dir = Path(_env_str("DOTS_TTS_OUTPUT_DIR", "/data/generated"))
    log_file = Path(_env_str("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))

    configure_logging(log_file=log_file)
    logger.info(
        "Space app starting: model={} execution_mode={} precision={} optimize={} "
        "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
        model_name_or_path,
        execution_mode,
        precision,
        optimize,
        compile_backend,
        enable_aoti,
        startup_compile,
        max_generate_length,
    )

    app_config = build_gradio_app_config(
        host=host,
        port=port,
        execution_mode=execution_mode,
        precision=precision,
        optimize=optimize,
        model_name_or_path=model_name_or_path,
        output_dir=output_dir,
        max_generate_length=max_generate_length,
        default_num_steps=default_num_steps,
        default_max_generate_length=max_generate_length,
        repo_root=REPO_ROOT,
    )
    app_service = GradioAppService(app_config)
    # With AOTI disabled the runtime is configured for the "torch_compile" backend.
    runtime, resolved_model_name_or_path = _preload_runtime(
        app_service,
        app_config,
        compile_backend if enable_aoti else "torch_compile",
    )

    if enable_aoti and startup_compile and optimize:

        @spaces.GPU(duration=compile_duration)
        def compile_aoti_cache():
            """Warm up the model inside a GPU allocation and export compiled artifacts."""
            child_runtime, _ = _preload_runtime(
                app_service,
                app_config,
                compile_backend,
            )
            child_runtime.model.run_warmup(
                max_generate_length=app_config.max_generate_length,
                precision=app_config.precision,
                num_steps=app_config.default_num_steps,
                guidance_scale=app_config.default_guidance_scale,
            )
            return child_runtime.model.export_compiled_models()

        # NOTE(review): presumably spaces.GPU runs the call above in a separate
        # worker, hence the explicit export there and import into `runtime` here.
        compiled_models = compile_aoti_cache()
        if compiled_models:
            runtime.model.import_compiled_models(compiled_models)
        logger.info(
            "AOTI startup compile completed: compiled_target_count={}",
            len(compiled_models or {}),
        )

    # Every generation request runs inside a GPU allocation of this duration.
    app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
    demo = build_demo(gr, app_config, app_service)
    logger.info(
        "Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
        app_config.host,
        app_config.port,
        resolved_model_name_or_path,
        len(runtime.model.export_compiled_models())
        if hasattr(runtime.model, "export_compiled_models")
        else 0,
    )
    demo.launch(
        server_name=app_config.host,
        server_port=app_config.port,
        theme=build_playground_theme(gr),
        css=PLAYGROUND_CSS,
    )


if __name__ == "__main__":
    main()