Spaces:
Running on Zero
Running on Zero
| from __future__ import annotations | |
| import os | |
| import sys | |
| from pathlib import Path | |
| from typing import Any, Callable | |
# Directory containing this file; project packages are resolved relative to it.
REPO_ROOT = Path(__file__).resolve().parent
SRC_ROOT = REPO_ROOT / "src"

# Put both roots at the front of ``sys.path`` (skipping entries that are
# already present).  Each insert goes to index 0, so the root processed last
# (``SRC_ROOT``) ends up ahead of ``REPO_ROOT`` when both are newly added.
for import_root in (REPO_ROOT, SRC_ROOT):
    import_root_str = str(import_root)
    if import_root_str in sys.path:
        continue
    sys.path.insert(0, import_root_str)
class _SpacesFallback:
    """No-op stand-in used when the ``spaces`` package cannot be imported.

    Only the ``GPU`` decorator is provided.  It returns the decorated
    function unchanged, so code written against ``spaces.GPU`` keeps working
    without the real package.
    """

    @staticmethod
    def GPU(*decorator_args, **_decorator_kwargs):
        """Return the decorated function unchanged.

        Supports both decorator spellings:

        * bare: ``@spaces.GPU`` -- the function arrives as the first
          positional argument and is returned as-is;
        * parameterised: ``@spaces.GPU(duration=...)`` -- the arguments are
          ignored and an identity decorator is returned.

        This is a static method so that calling it on an *instance* (the
        module-level ``spaces`` object is an instance of this class) does not
        pass the instance as the first positional argument; otherwise the
        bare form would return the inner decorator instead of the function.
        """
        if decorator_args and callable(decorator_args[0]):
            return decorator_args[0]

        def decorate(fn: Callable[..., Any]) -> Callable[..., Any]:
            return fn

        return decorate
# Bind ``spaces`` to the real package when it imports cleanly; any failure
# while importing it falls through to the no-op fallback object below.
try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover - only used outside Hugging Face Spaces.
    spaces = None  # type: ignore

if spaces is None:  # pragma: no cover - only used outside Hugging Face Spaces.
    spaces = _SpacesFallback()  # type: ignore
def _env_bool(name: str, default: bool) -> bool:
    """Read environment variable *name* as a boolean flag.

    Returns *default* when the variable is not set.  Otherwise the value is
    stripped and lower-cased, and the result is ``True`` only for ``"1"``,
    ``"true"``, ``"yes"`` or ``"on"``; every other value (including an empty
    string) yields ``False``.
    """
    try:
        raw = os.environ[name]
    except KeyError:
        return default
    normalized = raw.strip().lower()
    return normalized in ("1", "true", "yes", "on")
def _env_int(name: str, default: int) -> int:
    """Read environment variable *name* as an integer.

    Returns *default* when the variable is unset, empty, or whitespace-only.

    Raises:
        ValueError: if the variable holds a non-blank value that ``int()``
            cannot parse.  The message names the variable and the rejected
            value so a misconfigured deployment can be diagnosed from the
            traceback alone.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError as exc:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}"
        ) from exc
def _configure_zero_gpu_environment() -> None:
    """Seed default environment variables without overriding existing ones.

    Each variable is written only when it is not already present in
    ``os.environ``, so values supplied by the deployment take precedence.
    """
    zero_gpu_defaults = (
        ("DOTS_TTS_COMPILE_BACKEND", "aoti"),
        ("DOTS_TTS_SKIP_INIT_WARMUP", "1"),
        ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
    )
    for variable, fallback in zero_gpu_defaults:
        os.environ.setdefault(variable, fallback)
def _preload_runtime(app_service, app_config, compile_backend: str):
    """Fetch the runtime for the default model and apply compile settings.

    Args:
        app_service: object exposing ``_get_runtime(model_name_or_path)``,
            which returns a ``(runtime, resolved_model_name_or_path)`` pair.
        app_config: object providing ``default_model_name_or_path`` and
            ``optimize``.
        compile_backend: backend name passed to the model's
            ``set_compile_backend`` when the model offers that method.

    Returns:
        The ``(runtime, resolved_model_name_or_path)`` pair obtained from
        ``app_service``, after the runtime and its model have been configured.
    """
    default_model = app_config.default_model_name_or_path
    runtime, resolved_model_name_or_path = app_service._get_runtime(default_model)  # noqa: SLF001

    # Mirror the optimize flag onto both the runtime and its model.
    optimize_enabled = bool(app_config.optimize)
    runtime.optimize = optimize_enabled
    model = runtime.model
    model.set_optimize(optimize_enabled)

    # Not every model exposes a selectable compile backend.
    if hasattr(model, "set_compile_backend"):
        model.set_compile_backend(compile_backend)

    return runtime, resolved_model_name_or_path
def main() -> None:
    """Build and launch the Gradio demo for the Space.

    Steps, in order:

    1. Seed environment defaults and read the ``DOTS_TTS_*`` settings.
    2. Build the app config and service, then preload the default model
       runtime with the selected compile backend.
    3. When AOTI, startup compilation and optimization are all enabled, run a
       warmup inside a ``spaces.GPU`` call and import the exported compiled
       models into the preloaded runtime.
    4. Wrap ``app_service.generate`` with ``spaces.GPU`` and launch the demo.

    Raises:
        ValueError: if an integer-valued setting cannot be parsed.
    """
    _configure_zero_gpu_environment()

    # These imports are deferred until the environment defaults are in place
    # (presumably so the imported packages observe them -- confirm before
    # hoisting to module level).
    import gradio as gr
    from loguru import logger

    from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
    from apps.gradio.service import GradioAppService, build_gradio_app_config
    from dots_tts.utils.logging import configure_logging

    # --- Settings from the environment -----------------------------------
    host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0")
    port = _env_int("DOTS_TTS_PORT", 7860)
    model_name_or_path = os.environ.get(
        "DOTS_TTS_MODEL_NAME_OR_PATH",
        "rednote-hilab/dots.tts",
    )
    precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16")
    execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream")
    max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
    default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
    compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
    enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
    startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
    optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
    generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 60)
    compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
    output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/data/generated"))
    log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))

    configure_logging(log_file=log_file)
    logger.info(
        "Space app starting: model={} execution_mode={} precision={} optimize={} "
        "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
        model_name_or_path,
        execution_mode,
        precision,
        optimize,
        compile_backend,
        enable_aoti,
        startup_compile,
        max_generate_length,
    )

    # --- App config, service and preloaded runtime -----------------------
    app_config = build_gradio_app_config(
        host=host,
        port=port,
        execution_mode=execution_mode,
        precision=precision,
        optimize=optimize,
        model_name_or_path=model_name_or_path,
        output_dir=output_dir,
        max_generate_length=max_generate_length,
        default_num_steps=default_num_steps,
        default_max_generate_length=max_generate_length,
        repo_root=REPO_ROOT,
    )
    app_service = GradioAppService(app_config)
    runtime, resolved_model_name_or_path = _preload_runtime(
        app_service,
        app_config,
        compile_backend if enable_aoti else "torch_compile",
    )

    # --- Optional AOTI compile at startup --------------------------------
    if enable_aoti and startup_compile and optimize:

        # The warmup needs a GPU, so it runs under ``spaces.GPU`` with its own
        # (longer) duration budget; with the fallback ``spaces`` object the
        # decorator is a no-op and the function runs inline.
        @spaces.GPU(duration=compile_duration)
        def compile_aoti_cache():
            child_runtime, _ = _preload_runtime(
                app_service,
                app_config,
                compile_backend,
            )
            child_runtime.model.run_warmup(
                max_generate_length=app_config.max_generate_length,
                precision=app_config.precision,
                num_steps=app_config.default_num_steps,
                guidance_scale=app_config.default_guidance_scale,
            )
            return child_runtime.model.export_compiled_models()

        compiled_models = compile_aoti_cache()
        if compiled_models:
            runtime.model.import_compiled_models(compiled_models)
        logger.info(
            "AOTI startup compile completed: compiled_target_count={}",
            len(compiled_models or {}),
        )

    # --- Demo wiring and launch ------------------------------------------
    app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
    demo = build_demo(gr, app_config, app_service)

    # ``export_compiled_models`` is treated as possibly returning ``None``
    # above, so guard the ``len()`` here the same way.
    if hasattr(runtime.model, "export_compiled_models"):
        compiled_target_count = len(runtime.model.export_compiled_models() or {})
    else:
        compiled_target_count = 0
    logger.info(
        "Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
        app_config.host,
        app_config.port,
        resolved_model_name_or_path,
        compiled_target_count,
    )

    demo.launch(
        server_name=app_config.host,
        server_port=app_config.port,
        theme=build_playground_theme(gr),
        css=PLAYGROUND_CSS,
    )
# Launch the app only when this file is executed as a script, not on import.
if __name__ == "__main__":
    main()