# clis/moss_ttsd_app.py: Gradio demo for the MOSS-TTSD NF4 quantized model.
import argparse
import functools
import importlib.util
import re
import time
from pathlib import Path
from typing import Optional
import gradio as gr
import numpy as np
import torch
import torchaudio
from transformers import AutoModel, AutoProcessor
# Disable the broken cuDNN SDPA backend
torch.backends.cuda.enable_cudnn_sdp(False)
# Keep these enabled as fallbacks
torch.backends.cuda.enable_flash_sdp(True)
torch.backends.cuda.enable_mem_efficient_sdp(True)
torch.backends.cuda.enable_math_sdp(True)
MODEL_PATH = "OpenMOSS-Team/MOSS-TTSD-v1.0"
CODEC_MODEL_PATH = "OpenMOSS-Team/MOSS-Audio-Tokenizer"
DEFAULT_ATTN_IMPLEMENTATION = "auto"
DEFAULT_MAX_NEW_TOKENS = 2000
MIN_SPEAKERS = 1
MAX_SPEAKERS = 5
PRESET_REF_AUDIO_S1 = "assets/audio/reference_02_s1.wav"
PRESET_REF_AUDIO_S2 = "assets/audio/reference_02_s2.wav"
PRESET_PROMPT_TEXT_S1 = (
"[S1] In short, we embarked on a mission to make America great again for all Americans."
)
PRESET_PROMPT_TEXT_S2 = (
"[S2] NVIDIA reinvented computing for the first time after 60 years. In fact, Erwin at IBM knows quite "
"well that the computer has largely been the same since the 60s."
)
PRESET_DIALOGUE_TEXT = (
"[S1] Listen, let's talk business. China. I'm hearing things.\n"
"People are saying they're catching up. Fast. What's the real scoop?\n"
"Their AI, is it a threat?\n"
"[S2] Well, the pace of innovation there is extraordinary, honestly.\n"
"They have the researchers, and they have the drive.\n"
"[S1] Extraordinary? I don't like that. I want us to be extraordinary.\n"
"Are they winning?\n"
"[S2] I wouldn't say winning, but their progress is very promising.\n"
"They are building massive clusters. They're very determined.\n"
"[S1] Promising. There it is. I hate that word.\n"
"When China is promising, it means we're losing.\n"
"It's a disaster, Jensen. A total disaster."
)
PRESET_EXAMPLES = [
{
"name": "Quick Start | reference_02_s1/s2",
"speaker_count": 2,
"s1_audio": PRESET_REF_AUDIO_S1,
"s1_prompt": PRESET_PROMPT_TEXT_S1,
"s2_audio": PRESET_REF_AUDIO_S2,
"s2_prompt": PRESET_PROMPT_TEXT_S2,
"dialogue_text": PRESET_DIALOGUE_TEXT,
}
]
PRESET_DISPLAY_FIELDS = [
("Speaker Count", "speaker_count"),
("S1 Reference Audio (Optional)", "s1_audio"),
("S1 Prompt Text (Required with reference audio)", "s1_prompt"),
("S2 Reference Audio (Optional)", "s2_audio"),
("S2 Prompt Text (Required with reference audio)", "s2_prompt"),
("Dialogue Text", "dialogue_text"),
]
def _build_preset_table_rows():
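    """Flatten PRESET_EXAMPLES into display rows plus a row -> preset index map."""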
rows = []
row_to_preset = []
for preset_idx, preset in enumerate(PRESET_EXAMPLES):
for field_name, field_key in PRESET_DISPLAY_FIELDS:
value = str(preset.get(field_key, ""))
if field_key == "dialogue_text":
value = value.replace("\n", " ").strip()
if len(value) > 120:
value = value[:120] + " ..."
rows.append([field_name, value])
row_to_preset.append(preset_idx)
return rows, row_to_preset
PRESET_TABLE_ROWS, PRESET_TABLE_ROW_TO_PRESET = _build_preset_table_rows()
def resolve_attn_implementation(requested: str, device: torch.device, dtype: torch.dtype) -> str | None:
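    """Resolve an `attn_implementation` string for from_pretrained.

    "none" disables the kwarg entirely; any other explicit value passes
    through verbatim. For "auto"/empty: prefer flash_attention_2 on CUDA
    devices with compute capability >= 8.0 when flash_attn is installed
    and the dtype is fp16/bf16, fall back to "sdpa" on other CUDA
    devices, and use "eager" on CPU.
    """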
requested_norm = (requested or "").strip().lower()
if requested_norm in {"none"}:
return None
if requested_norm not in {"", "auto"}:
return requested
# Prefer FlashAttention 2 when package + device conditions are met.
if (
device.type == "cuda"
and importlib.util.find_spec("flash_attn") is not None
and dtype in {torch.float16, torch.bfloat16}
):
major, _ = torch.cuda.get_device_capability(device)
if major >= 8:
return "flash_attention_2"
# CUDA fallback: use PyTorch SDPA kernels.
if device.type == "cuda":
return "sdpa"
# CPU fallback.
return "eager"
@functools.lru_cache(maxsize=1)
def load_backend(model_path: str, codec_path: str, device_str: str, attn_implementation: str):
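    """Load and memoize the model, processor, device, and sample rate.

    lru_cache(maxsize=1) keeps exactly one backend resident: requesting a
    different (model, codec, device, attention) combination evicts and
    reloads rather than holding two models in memory at once.
    """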
device = torch.device(device_str if torch.cuda.is_available() else "cpu")
dtype = torch.bfloat16 if device.type == "cuda" else torch.float32
resolved_attn_implementation = resolve_attn_implementation(
requested=attn_implementation,
device=device,
dtype=dtype,
)
processor = AutoProcessor.from_pretrained(
model_path,
trust_remote_code=True,
codec_path=codec_path,
)
if hasattr(processor, "audio_tokenizer"):
processor.audio_tokenizer = processor.audio_tokenizer.to(device)
processor.audio_tokenizer.eval()
model_kwargs = {
"trust_remote_code": True,
"torch_dtype": dtype,
}
if resolved_attn_implementation:
model_kwargs["attn_implementation"] = resolved_attn_implementation
model = AutoModel.from_pretrained(model_path, **model_kwargs).to(device)
model.eval()
sample_rate = int(getattr(processor.model_config, "sampling_rate", 24000))
return model, processor, device, sample_rate
def _resample_wav(wav: torch.Tensor, orig_sr: int, target_sr: int) -> torch.Tensor:
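    # Linear interpolation keeps this helper dependency-free, at the cost of
    # some aliasing versus a proper windowed-sinc resampler (e.g. torchaudio's).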
if int(orig_sr) == int(target_sr):
return wav
new_num_samples = int(round(wav.shape[-1] * float(target_sr) / float(orig_sr)))
if new_num_samples <= 0:
raise ValueError(f"Invalid resample length from {orig_sr}Hz to {target_sr}Hz.")
return torch.nn.functional.interpolate(
wav.unsqueeze(0),
size=new_num_samples,
mode="linear",
align_corners=False,
).squeeze(0)
def _load_audio(audio_path: str) -> tuple[torch.Tensor, int]:
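    """Load a reference clip as a mono (1, T) tensor plus its sample rate."""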
path = Path(audio_path).expanduser()
if not path.exists():
raise FileNotFoundError(f"Reference audio not found: {path}")
wav, sr = torchaudio.load(str(path))
if wav.numel() == 0:
raise ValueError(f"Reference audio is empty: {path}")
if wav.shape[0] > 1:
wav = wav.mean(dim=0, keepdim=True)
return wav, int(sr)
def normalize_text(text: str) -> str:
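    """Normalize dialogue text before synthesis.

    Rewrites bare [1]/[2] tags to [S1]/[S2], strips decorative brackets and
    quotes, maps laughter runs ("哈哈", "ha ha") to [笑]/[laugh] tokens,
    collapses dashes, ellipses, and internal punctuation into commas, and
    merges consecutive segments from the same speaker.

    Example (traced against this implementation):
        "[1] Hello... ha ha!"  ->  "[S1]Hello, [laugh]!"
    """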
text = re.sub(r"\[(\d+)\]", r"[S\1]", text)
remove_chars = "【】《》()『』「」" '"-_“”~~‘’'
segments = re.split(r"(?=\[S\d+\])", text.replace("\n", " "))
processed_parts = []
for seg in segments:
seg = seg.strip()
if not seg:
continue
matched = re.match(r"^(\[S\d+\])\s*(.*)", seg)
tag, content = matched.groups() if matched else ("", seg)
content = re.sub(f"[{re.escape(remove_chars)}]", "", content)
content = re.sub(r"哈{2,}", "[笑]", content)
content = re.sub(r"\b(ha(\s*ha)+)\b", "[laugh]", content, flags=re.IGNORECASE)
content = content.replace("——", ",")
content = content.replace("……", ",")
content = content.replace("...", ",")
content = content.replace("⸺", ",")
content = content.replace("―", ",")
content = content.replace("—", ",")
content = content.replace("…", ",")
internal_punct_map = str.maketrans(
{";": ",", ";": ",", ":": ",", ":": ",", "、": ","}
)
content = content.translate(internal_punct_map)
content = content.strip()
content = re.sub(r"([,。?!,.?!])[,。?!,.?!]+", r"\1", content)
if len(content) > 1:
last_ch = "。" if content[-1] == "," else ("." if content[-1] == "," else content[-1])
body = content[:-1].replace("。", ",")
content = body + last_ch
processed_parts.append({"tag": tag, "content": content})
if not processed_parts:
return ""
merged_lines = []
current_tag = processed_parts[0]["tag"]
current_content = [processed_parts[0]["content"]]
for part in processed_parts[1:]:
if part["tag"] == current_tag and current_tag:
current_content.append(part["content"])
else:
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
current_tag = part["tag"]
current_content = [part["content"]]
merged_lines.append(f"{current_tag}{''.join(current_content)}".strip())
return "".join(merged_lines).replace("‘", "'").replace("’", "'")
def _validate_dialogue_text(dialogue_text: str, speaker_count: int) -> str:
text = (dialogue_text or "").strip()
if not text:
raise ValueError("Please enter dialogue text.")
tags = re.findall(r"\[S(\d+)\]", text)
if not tags:
raise ValueError("Dialogue must include speaker tags like [S1], [S2], ...")
max_tag = max(int(t) for t in tags)
if max_tag > speaker_count:
raise ValueError(
f"Dialogue contains [S{max_tag}], but speaker count is set to {speaker_count}."
)
return text
def update_speaker_panels(speaker_count: int):
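    """Show only the first `speaker_count` panels, clamped to MIN..MAX_SPEAKERS."""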
count = int(speaker_count)
count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, count))
return [gr.update(visible=(idx < count)) for idx in range(MAX_SPEAKERS)]
def apply_preset_selection(evt: gr.SelectData):
    # Six form-field updates plus one visibility update per speaker panel.
    no_change = (
        *[gr.update() for _ in range(6)],
        *[gr.update() for _ in range(MAX_SPEAKERS)],
    )
    if evt is None or evt.index is None:
        return no_change
    if isinstance(evt.index, (tuple, list)):
        row_idx = int(evt.index[0])
    else:
        row_idx = int(evt.index)
    if not 0 <= row_idx < len(PRESET_TABLE_ROW_TO_PRESET):
        return no_change
    preset_idx = PRESET_TABLE_ROW_TO_PRESET[row_idx]
    if not 0 <= preset_idx < len(PRESET_EXAMPLES):
        return no_change
    preset = PRESET_EXAMPLES[preset_idx]
panel_updates = update_speaker_panels(int(preset["speaker_count"]))
return (
gr.update(value=int(preset["speaker_count"])),
gr.update(value=str(preset["s1_audio"])),
gr.update(value=str(preset["s1_prompt"])),
gr.update(value=str(preset["s2_audio"])),
gr.update(value=str(preset["s2_prompt"])),
gr.update(value=str(preset["dialogue_text"])),
*panel_updates,
)
def _merge_consecutive_speaker_tags(text: str) -> str:
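    """Collapse back-to-back segments that carry the same speaker tag.

    Example (traced against this implementation):
        "[S1]Hi. [S1]There. [S2]Hey."  ->  "[S1]Hi.There.[S2]Hey."
    """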
segments = re.split(r"(?=\[S\d+\])", text)
if not segments:
return text
merged_parts = []
current_tag = None
for seg in segments:
seg = seg.strip()
if not seg:
continue
matched = re.match(r"^(\[S\d+\])\s*(.*)", seg, re.DOTALL)
if not matched:
merged_parts.append(seg)
continue
tag, content = matched.groups()
if tag == current_tag:
merged_parts.append(content)
else:
current_tag = tag
merged_parts.append(f"{tag}{content}")
return "".join(merged_parts)
def _normalize_prompt_text(prompt_text: str, speaker_id: int) -> str:
text = (prompt_text or "").strip()
if not text:
raise ValueError(f"S{speaker_id} prompt text is empty.")
expected_tag = f"[S{speaker_id}]"
if not text.lstrip().startswith(expected_tag):
text = f"{expected_tag} {text}"
return text
def _build_prefixed_text(
dialogue_text: str,
prompt_text_map: dict[int, str],
cloned_speakers: list[int],
) -> str:
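    # Each cloned speaker's prompt transcript is prepended, in speaker order,
    # so the reference text immediately precedes the dialogue it conditions.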
prompt_prefix = "".join([prompt_text_map[speaker_id] for speaker_id in cloned_speakers])
return _merge_consecutive_speaker_tags(prompt_prefix + dialogue_text)
def _encode_reference_audio_codes(
processor,
clone_wavs: list[torch.Tensor],
cloned_speakers: list[int],
speaker_count: int,
sample_rate: int,
) -> list[Optional[torch.Tensor]]:
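    # Slot i holds the codes for speaker S{i+1}; speakers without reference
    # audio keep None so downstream indexing stays tag-aligned.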
encoded_list = processor.encode_audios_from_wav(clone_wavs, sampling_rate=sample_rate)
reference_audio_codes: list[Optional[torch.Tensor]] = [None for _ in range(speaker_count)]
for speaker_id, audio_codes in zip(cloned_speakers, encoded_list):
reference_audio_codes[speaker_id - 1] = audio_codes
return reference_audio_codes
def build_conversation(
dialogue_text: str,
reference_audio_codes: list[Optional[torch.Tensor]],
prompt_audio: torch.Tensor | None,
processor,
):
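    """Build the processor conversation and select the decoding mode.

    With no prompt audio this is plain text-to-dialogue generation; with
    prompt audio the per-speaker reference codes attach to the user message
    and the concatenated prompt waveform seeds an assistant turn, switching
    the processor into continuation mode.
    """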
if prompt_audio is None:
return [[processor.build_user_message(text=dialogue_text)]], "generation", "Generation"
user_message = processor.build_user_message(
text=dialogue_text,
reference=reference_audio_codes,
)
return (
[
[
user_message,
processor.build_assistant_message(audio_codes_list=[prompt_audio]),
],
],
"continuation",
"voice_clone_and_continuation",
)
def run_inference(speaker_count: int, *all_inputs):
speaker_count = int(speaker_count)
speaker_count = max(MIN_SPEAKERS, min(MAX_SPEAKERS, speaker_count))
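    # Flat Gradio input layout: MAX_SPEAKERS reference audio paths, then
    # MAX_SPEAKERS prompt texts, the dialogue text, and finally the sampling
    # and backend options appended by the click handler.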
reference_audio_values = all_inputs[:MAX_SPEAKERS]
prompt_text_values = all_inputs[MAX_SPEAKERS : 2 * MAX_SPEAKERS]
dialogue_text = all_inputs[2 * MAX_SPEAKERS]
    (
        text_normalize,
        sample_rate_normalize,
        temperature,
        top_p,
        top_k,
        repetition_penalty,
        max_new_tokens,
        model_path,
        codec_path,
        device,
        attn_implementation,
    ) = all_inputs[2 * MAX_SPEAKERS + 1 :]
started_at = time.monotonic()
model, processor, torch_device, sample_rate = load_backend(
model_path=str(model_path),
codec_path=str(codec_path),
device_str=str(device),
attn_implementation=str(attn_implementation),
)
text_normalize = bool(text_normalize)
sample_rate_normalize = bool(sample_rate_normalize)
normalized_dialogue = str(dialogue_text or "").strip()
if text_normalize:
normalized_dialogue = normalize_text(normalized_dialogue)
normalized_dialogue = _validate_dialogue_text(normalized_dialogue, speaker_count)
cloned_speakers: list[int] = []
loaded_clone_wavs: list[tuple[torch.Tensor, int]] = []
prompt_text_map: dict[int, str] = {}
for idx in range(speaker_count):
ref_audio = reference_audio_values[idx]
prompt_text = str(prompt_text_values[idx] or "").strip()
has_reference = bool(ref_audio)
has_prompt_text = bool(prompt_text)
if has_reference != has_prompt_text:
raise ValueError(
f"S{idx + 1} must provide both reference audio and prompt text together."
)
if has_reference:
speaker_id = idx + 1
ref_audio_path = str(ref_audio)
cloned_speakers.append(speaker_id)
loaded_clone_wavs.append(_load_audio(ref_audio_path))
prompt_text_map[speaker_id] = _normalize_prompt_text(prompt_text, speaker_id)
prompt_audio: Optional[torch.Tensor] = None
reference_audio_codes: list[Optional[torch.Tensor]] = []
conversation_text = normalized_dialogue
if cloned_speakers:
conversation_text = _build_prefixed_text(
dialogue_text=normalized_dialogue,
prompt_text_map=prompt_text_map,
cloned_speakers=cloned_speakers,
)
if text_normalize:
conversation_text = normalize_text(conversation_text)
conversation_text = _validate_dialogue_text(conversation_text, speaker_count)
if sample_rate_normalize:
min_sr = min(sr for _, sr in loaded_clone_wavs)
else:
min_sr = None
clone_wavs: list[torch.Tensor] = []
for wav, orig_sr in loaded_clone_wavs:
processed_wav = wav
current_sr = int(orig_sr)
if min_sr is not None:
processed_wav = _resample_wav(processed_wav, current_sr, int(min_sr))
current_sr = int(min_sr)
processed_wav = _resample_wav(processed_wav, current_sr, sample_rate)
clone_wavs.append(processed_wav)
reference_audio_codes = _encode_reference_audio_codes(
processor=processor,
clone_wavs=clone_wavs,
cloned_speakers=cloned_speakers,
speaker_count=speaker_count,
sample_rate=sample_rate,
)
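        # Concatenating the reference waveforms yields a single prompt-audio
        # turn for the assistant message; per-speaker codes remain separate
        # as references on the user message.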
concat_prompt_wav = torch.cat(clone_wavs, dim=-1)
prompt_audio = processor.encode_audios_from_wav([concat_prompt_wav], sampling_rate=sample_rate)[0]
conversations, mode, mode_name = build_conversation(
dialogue_text=conversation_text,
reference_audio_codes=reference_audio_codes,
prompt_audio=prompt_audio,
processor=processor,
)
batch = processor(conversations, mode=mode)
input_ids = batch["input_ids"].to(torch_device)
attention_mask = batch["attention_mask"].to(torch_device)
with torch.no_grad():
outputs = model.generate(
input_ids=input_ids,
attention_mask=attention_mask,
max_new_tokens=int(max_new_tokens),
audio_temperature=float(temperature),
audio_top_p=float(top_p),
audio_top_k=int(top_k),
audio_repetition_penalty=float(repetition_penalty),
)
messages = processor.decode(outputs)
if not messages or messages[0] is None:
raise RuntimeError("The model did not return a decodable audio result.")
audio = messages[0].audio_codes_list[0]
if isinstance(audio, torch.Tensor):
audio_np = audio.detach().float().cpu().numpy()
else:
audio_np = np.asarray(audio, dtype=np.float32)
if audio_np.ndim > 1:
audio_np = audio_np.reshape(-1)
audio_np = audio_np.astype(np.float32, copy=False)
clone_summary = "none" if not cloned_speakers else ",".join([f"S{i}" for i in cloned_speakers])
elapsed = time.monotonic() - started_at
status = (
f"Done | mode={mode_name} | speakers={speaker_count} | cloned={clone_summary} | elapsed={elapsed:.2f}s | "
f"text_normalize={text_normalize}, sample_rate_normalize={sample_rate_normalize} | "
f"max_new_tokens={int(max_new_tokens)}, "
f"audio_temperature={float(temperature):.2f}, audio_top_p={float(top_p):.2f}, "
f"audio_top_k={int(top_k)}, audio_repetition_penalty={float(repetition_penalty):.2f}"
)
return (sample_rate, audio_np), status
def build_demo(args: argparse.Namespace):
custom_css = """
:root {
--bg: #f6f7f8;
--panel: #ffffff;
--ink: #111418;
--muted: #4d5562;
--line: #e5e7eb;
--accent: #0f766e;
}
.gradio-container {
background: linear-gradient(180deg, #f7f8fa 0%, #f3f5f7 100%);
color: var(--ink);
}
.app-card {
border: 1px solid var(--line);
border-radius: 16px;
background: var(--panel);
padding: 14px;
}
.app-title {
font-size: 22px;
font-weight: 700;
margin-bottom: 6px;
letter-spacing: 0.2px;
}
.app-subtitle {
color: var(--muted);
font-size: 14px;
margin-bottom: 8px;
}
#output_panel {
overflow: hidden !important;
}
#output_audio {
padding-bottom: 24px;
margin-bottom: 0;
overflow: hidden !important;
}
#output_audio > .wrap,
#output_audio .wrap,
#output_audio .audio-container,
#output_audio .block {
overflow: hidden !important;
}
#output_audio .audio-container {
padding-bottom: 10px;
min-height: 96px;
}
#output_audio_spacer {
height: 12px;
}
#output_status {
margin-top: 0;
}
#run-btn {
background: var(--accent);
border: none;
}
"""
with gr.Blocks(title="MOSS-TTSD Demo", css=custom_css) as demo:
gr.Markdown(
"""
<div class="app-card">
<div class="app-title">MOSS-TTSD</div>
<div class="app-subtitle">Multi-speaker dialogue synthesis with optional per-speaker voice cloning.</div>
</div>
"""
)
speaker_panels: list[gr.Group] = []
speaker_refs = []
speaker_prompts = []
with gr.Row(equal_height=False):
with gr.Column(scale=3):
speaker_count = gr.Slider(
minimum=MIN_SPEAKERS,
maximum=MAX_SPEAKERS,
step=1,
value=2,
label="Speaker Count",
info="Default 2 speakers. Minimum 1, maximum 5.",
)
gr.Markdown("### Voice Cloning (Optional, placed first)")
gr.Markdown(
"If you provide reference audio for a speaker, you must also provide that speaker's prompt text. "
"Prompt text may omit [Sx]; the app will auto-prepend it."
)
for idx in range(1, MAX_SPEAKERS + 1):
with gr.Group(visible=idx <= 2) as panel:
speaker_ref = gr.Audio(
label=f"S{idx} Reference Audio (Optional)",
type="filepath",
)
speaker_prompt = gr.Textbox(
label=f"S{idx} Prompt Text (Required with reference audio)",
lines=2,
placeholder=f"Example: [S{idx}] This is a prompt line for S{idx}.",
)
speaker_panels.append(panel)
speaker_refs.append(speaker_ref)
speaker_prompts.append(speaker_prompt)
gr.Markdown("### Multi-turn Dialogue")
dialogue_text = gr.Textbox(
label="Dialogue Text",
lines=12,
placeholder=(
"Use explicit tags in a single box, e.g.\n"
"[S1] Hello.\n"
"[S2] Hi, how are you?\n"
"[S1] Great, let's continue."
),
)
gr.Markdown(
"Without any reference audio, the model runs in generation mode. "
"Once any reference audio is provided, the model switches to voice-clone continuation mode."
)
with gr.Accordion("Sampling Parameters (Audio)", open=True):
gr.Markdown(
"- `text_normalize`: Normalize input text (**recommended to always enable**).\n"
"- `sample_rate_normalize`: Resample prompt audios to the lowest sample rate before encoding "
"(**recommended when using 2 or more speakers**)."
)
text_normalize = gr.Checkbox(
value=True,
label="text_normalize",
)
sample_rate_normalize = gr.Checkbox(
value=False,
label="sample_rate_normalize",
)
temperature = gr.Slider(
minimum=0.1,
maximum=3.0,
step=0.05,
value=1.1,
label="temperature",
)
top_p = gr.Slider(
minimum=0.1,
maximum=1.0,
step=0.01,
value=0.9,
label="top_p",
)
top_k = gr.Slider(
minimum=1,
maximum=200,
step=1,
value=50,
label="top_k",
)
repetition_penalty = gr.Slider(
minimum=0.8,
maximum=2.0,
step=0.05,
value=1.1,
label="repetition_penalty",
)
max_new_tokens = gr.Slider(
minimum=256,
maximum=8192,
step=128,
value=DEFAULT_MAX_NEW_TOKENS,
label="max_new_tokens",
)
run_btn = gr.Button("Generate Dialogue Audio", variant="primary", elem_id="run-btn")
with gr.Column(scale=2, elem_id="output_panel"):
output_audio = gr.Audio(label="Output Audio", type="numpy", elem_id="output_audio")
gr.HTML("", elem_id="output_audio_spacer")
status = gr.Textbox(label="Status", lines=4, interactive=False, elem_id="output_status")
preset_examples = gr.Dataframe(
headers=["Field", "Value (click any row to fill inputs)"],
value=PRESET_TABLE_ROWS,
datatype=["str", "str"],
row_count=(len(PRESET_TABLE_ROWS), "fixed"),
col_count=(2, "fixed"),
interactive=False,
wrap=True,
label="Preset Examples",
)
speaker_count.change(
fn=update_speaker_panels,
inputs=[speaker_count],
outputs=speaker_panels,
)
preset_examples.select(
fn=apply_preset_selection,
outputs=[
speaker_count,
speaker_refs[0],
speaker_prompts[0],
speaker_refs[1],
speaker_prompts[1],
dialogue_text,
*speaker_panels,
],
)
run_btn.click(
fn=lambda speaker_count, *inputs: run_inference(
speaker_count,
*inputs,
args.model_path,
args.codec_path,
args.device,
args.attn_implementation,
),
inputs=[
speaker_count,
*speaker_refs,
*speaker_prompts,
dialogue_text,
text_normalize,
sample_rate_normalize,
temperature,
top_p,
top_k,
repetition_penalty,
max_new_tokens,
],
outputs=[output_audio, status],
)
return demo
def main() -> None:
parser = argparse.ArgumentParser(description="MOSS-TTSD Gradio Demo")
parser.add_argument("--model_path", type=str, default=MODEL_PATH)
parser.add_argument("--codec_path", type=str, default=CODEC_MODEL_PATH)
parser.add_argument("--device", type=str, default="cuda:0")
parser.add_argument("--attn_implementation", type=str, default=DEFAULT_ATTN_IMPLEMENTATION)
parser.add_argument("--host", type=str, default="0.0.0.0")
parser.add_argument("--port", type=int, default=7863)
parser.add_argument("--share", action="store_true")
args = parser.parse_args()
runtime_device = torch.device(args.device if torch.cuda.is_available() else "cpu")
runtime_dtype = torch.bfloat16 if runtime_device.type == "cuda" else torch.float32
args.attn_implementation = resolve_attn_implementation(
requested=args.attn_implementation,
device=runtime_device,
dtype=runtime_dtype,
) or "none"
print(f"[INFO] Using attn_implementation={args.attn_implementation}", flush=True)
preload_started_at = time.monotonic()
print(
f"[Startup] Preloading backend: model={args.model_path}, codec={args.codec_path}, "
f"device={args.device}, attn={args.attn_implementation}",
flush=True,
)
load_backend(
model_path=args.model_path,
codec_path=args.codec_path,
device_str=args.device,
attn_implementation=args.attn_implementation,
)
print(
f"[Startup] Backend preload finished in {time.monotonic() - preload_started_at:.2f}s",
flush=True,
)
demo = build_demo(args)
demo.queue(default_concurrency_limit=2).launch(
server_name=args.host,
server_port=args.port,
share=args.share,
)
if __name__ == "__main__":
main()
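
# Example launch (illustrative; these flags and defaults come from the
# argparse setup above):
#   python moss_ttsd_app.py --device cuda:0 --port 7863 --attn_implementation auto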