dots.tts / app.py
YuMS's picture
set default generation time to 60s
927c02a
from __future__ import annotations
import os
import sys
from pathlib import Path
from typing import Any, Callable
REPO_ROOT = Path(__file__).resolve().parent
SRC_ROOT = REPO_ROOT / "src"
# Make both the repository root and its src/ directory importable when this
# script is run directly. Each path is prepended only if absent; because
# SRC_ROOT is handled last, it ends up ahead of REPO_ROOT when both are added.
for import_root in (REPO_ROOT, SRC_ROOT):
    import_root_str = str(import_root)
    if import_root_str in sys.path:
        continue
    sys.path.insert(0, import_root_str)
class _SpacesFallback:
    """No-op stand-in for the Hugging Face ``spaces`` module.

    Only ``GPU`` is provided. It supports both the bare decorator form
    (``@spaces.GPU``) and the parameterized form (``@spaces.GPU(duration=...)``
    or ``spaces.GPU(duration=...)(fn)``). In every case the decorated callable
    is returned unchanged.
    """

    @staticmethod
    def GPU(*decorator_args, **_decorator_kwargs):
        # Bare usage: the function itself is the first positional argument.
        first = decorator_args[0] if decorator_args else None
        if callable(first):
            return first
        # Parameterized usage: keyword options are ignored, and an identity
        # decorator is handed back.
        return lambda fn: fn


try:
    import spaces  # type: ignore
except Exception:  # pragma: no cover - only used outside Hugging Face Spaces.
    # The package may be missing or fail to import outside Spaces. Swap in the
    # no-op shim so the ``spaces.GPU`` decorators used below still work.
    spaces = _SpacesFallback()  # type: ignore
def _env_bool(name: str, default: bool) -> bool:
    """Read a boolean flag from the environment variable *name*.

    An unset, empty, or whitespace-only value falls back to *default*, matching
    ``_env_int``. Without this, ``FLAG=""`` would silently disable a flag that
    defaults to True. Otherwise "1", "true", "yes" and "on" are truthy
    (case-insensitive, surrounding whitespace ignored); any other value is
    falsy.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    return value.strip().lower() in {"1", "true", "yes", "on"}
def _env_int(name: str, default: int) -> int:
    """Read an integer from the environment variable *name*.

    An unset, empty, or whitespace-only value falls back to *default*.
    Surrounding whitespace is accepted (``int`` strips it).

    Raises:
        ValueError: if the variable is set to something ``int`` cannot parse.
            The message names the variable so a bad Space setting is easy to find.
    """
    value = os.environ.get(name)
    if value is None or not value.strip():
        return default
    try:
        return int(value)
    except ValueError as err:
        raise ValueError(
            f"Environment variable {name} must be an integer, got {value!r}"
        ) from err
def _configure_zero_gpu_environment() -> None:
    """Seed ZeroGPU-oriented environment defaults without overriding user values.

    Each variable below is written only when it is not already present in
    ``os.environ``, so explicit Space settings always win.
    """
    defaults = (
        ("DOTS_TTS_COMPILE_BACKEND", "aoti"),
        ("DOTS_TTS_SKIP_INIT_WARMUP", "1"),
        ("PYTORCH_CUDA_ALLOC_CONF", "expandable_segments:True"),
    )
    for key, fallback in defaults:
        os.environ.setdefault(key, fallback)
def _preload_runtime(app_service, app_config, compile_backend: str):
    """Fetch the default-model runtime from *app_service* and apply compile settings.

    The runtime's ``optimize`` attribute and its model's ``set_optimize`` both
    receive ``bool(app_config.optimize)``. If the model exposes
    ``set_compile_backend``, it is also given *compile_backend*.

    Returns:
        ``(runtime, resolved_model_name_or_path)`` exactly as produced by
        ``app_service._get_runtime``.
    """
    fetched = app_service._get_runtime(  # noqa: SLF001
        app_config.default_model_name_or_path,
    )
    runtime, resolved_model_name_or_path = fetched

    should_optimize = bool(app_config.optimize)
    runtime.optimize = should_optimize
    model = runtime.model
    model.set_optimize(should_optimize)
    # Not every model implementation supports choosing a compile backend.
    if hasattr(model, "set_compile_backend"):
        model.set_compile_backend(compile_backend)
    return runtime, resolved_model_name_or_path
def _compiled_target_count(model: Any) -> int:
    """Return how many compiled targets *model* currently exports.

    A model without ``export_compiled_models`` reports 0. A falsy export
    (e.g. ``None``) counts as empty instead of making ``len`` raise
    ``TypeError``; this is the same ``or {}`` fallback applied to the AOTI
    startup-compile result.
    """
    if not hasattr(model, "export_compiled_models"):
        return 0
    return len(model.export_compiled_models() or {})


def main() -> None:
    """Configure, preload, optionally AOTI-compile, and launch the dots.tts Space.

    Every setting is read from an environment variable, with the default shown
    at each read below. The default model runtime is preloaded in this process.

    When AOTI, startup compilation, and optimization are all enabled, a one-off
    ``spaces.GPU`` call warms up a runtime, compiles it, and returns the
    exported compiled models. These are then imported into the preloaded
    runtime. Finally ``app_service.generate`` is wrapped with ``spaces.GPU``
    and the Gradio demo is launched.
    """
    # Seed ZeroGPU defaults before the deferred imports below. Presumably some
    # of those modules read these variables at import time — TODO confirm.
    _configure_zero_gpu_environment()

    import gradio as gr
    from loguru import logger

    from apps.gradio.app import PLAYGROUND_CSS, build_demo, build_playground_theme
    from apps.gradio.service import GradioAppService, build_gradio_app_config
    from dots_tts.utils.logging import configure_logging

    host = os.environ.get("DOTS_TTS_HOST", "0.0.0.0")
    port = _env_int("DOTS_TTS_PORT", 7860)
    model_name_or_path = os.environ.get(
        "DOTS_TTS_MODEL_NAME_OR_PATH",
        "rednote-hilab/dots.tts",
    )
    precision = os.environ.get("DOTS_TTS_PRECISION", "bfloat16")
    execution_mode = os.environ.get("DOTS_TTS_EXECUTION_MODE", "generate_stream")
    max_generate_length = _env_int("DOTS_TTS_MAX_GENERATE_LENGTH", 500)
    default_num_steps = _env_int("DOTS_TTS_DEFAULT_NUM_STEPS", 10)
    compile_backend = os.environ.get("DOTS_TTS_COMPILE_BACKEND", "aoti").strip().lower()
    enable_aoti = _env_bool("DOTS_TTS_ENABLE_AOTI", True)
    startup_compile = _env_bool("DOTS_TTS_AOTI_COMPILE_ON_STARTUP", True)
    optimize = _env_bool("DOTS_TTS_OPTIMIZE", True)
    # ZeroGPU `duration` values (seconds) for each generation call and for the
    # one-off startup compile.
    generation_duration = _env_int("DOTS_TTS_ZERO_GPU_DURATION", 60)
    compile_duration = _env_int("DOTS_TTS_ZERO_GPU_COMPILE_DURATION", 1500)
    output_dir = Path(os.environ.get("DOTS_TTS_OUTPUT_DIR", "/data/generated"))
    log_file = Path(os.environ.get("DOTS_TTS_LOG_FILE", "/tmp/dots_tts_gradio.log"))

    configure_logging(log_file=log_file)
    logger.info(
        "Space app starting: model={} execution_mode={} precision={} optimize={} "
        "compile_backend={} enable_aoti={} startup_compile={} max_generate_length={}",
        model_name_or_path,
        execution_mode,
        precision,
        optimize,
        compile_backend,
        enable_aoti,
        startup_compile,
        max_generate_length,
    )

    app_config = build_gradio_app_config(
        host=host,
        port=port,
        execution_mode=execution_mode,
        precision=precision,
        optimize=optimize,
        model_name_or_path=model_name_or_path,
        output_dir=output_dir,
        max_generate_length=max_generate_length,
        default_num_steps=default_num_steps,
        default_max_generate_length=max_generate_length,
        repo_root=REPO_ROOT,
    )
    app_service = GradioAppService(app_config)

    # Preload the default runtime in this process. With AOTI disabled, the
    # model is pointed at the "torch_compile" backend instead of the
    # configured one.
    runtime, resolved_model_name_or_path = _preload_runtime(
        app_service,
        app_config,
        compile_backend if enable_aoti else "torch_compile",
    )

    if enable_aoti and startup_compile and optimize:

        @spaces.GPU(duration=compile_duration)
        def compile_aoti_cache():
            # Runs under spaces.GPU (presumably a separate GPU worker on
            # ZeroGPU — confirm). The compiled models are therefore exported
            # and returned, not kept on this runtime.
            child_runtime, _ = _preload_runtime(
                app_service,
                app_config,
                compile_backend,
            )
            child_runtime.model.run_warmup(
                max_generate_length=app_config.max_generate_length,
                precision=app_config.precision,
                num_steps=app_config.default_num_steps,
                guidance_scale=app_config.default_guidance_scale,
            )
            return child_runtime.model.export_compiled_models()

        compiled_models = compile_aoti_cache()
        if compiled_models:
            runtime.model.import_compiled_models(compiled_models)
        logger.info(
            "AOTI startup compile completed: compiled_target_count={}",
            len(compiled_models or {}),
        )

    # Wrap generation before build_demo so each generate call runs under
    # spaces.GPU with generation_duration. Presumably build_demo wires up
    # app_service.generate — confirm.
    app_service.generate = spaces.GPU(duration=generation_duration)(app_service.generate)
    demo = build_demo(gr, app_config, app_service)

    logger.info(
        "Space app ready: host={} port={} resolved_model={} compiled_target_count={}",
        app_config.host,
        app_config.port,
        resolved_model_name_or_path,
        _compiled_target_count(runtime.model),
    )
    demo.launch(
        server_name=app_config.host,
        server_port=app_config.port,
        theme=build_playground_theme(gr),
        css=PLAYGROUND_CSS,
    )
if __name__ == "__main__":
    # Launch the Space only when this file is executed as a script, not on import.
    main()