tts-v21 / demo.py
michael-chan-000's picture
Upload model
79529ed verified
"""demo.py — quick smoke test for vocence_miner_v1.
Reads the merged checkpoint either from a local path or from the Hugging Face Hub,
then generates a small set of preset clips that exercise the prompt-following range.
Usage:

    pip install qwen-tts transformers torch soundfile
    python demo.py                                        # uses the current directory
    python demo.py --source magma90909/vocence_miner_v8   # pull from HF
"""
from __future__ import annotations
import argparse
import dataclasses
import sys
from pathlib import Path
import soundfile as sf
import torch
from qwen_tts import Qwen3TTSModel
@dataclasses.dataclass(frozen=True)
class Sample:
    """One preset clip: the text to speak and the voice prompt to speak it with.

    Instances are immutable (``frozen=True``).
    """

    # Short identifier; synth_one uses it as the stem of the output .wav name.
    slug: str
    # Text handed to the model as ``text``.
    say: str
    # Voice description handed to the model as ``instruct``.
    voice: str
# Preset clips rendered by the demo. Each row is (slug, say, voice) and is
# turned into a Sample in the order listed.
SAMPLES: tuple[Sample, ...] = tuple(
    Sample(slug=slug, say=say, voice=voice)
    for slug, say, voice in (
        (
            "warm_male_storyteller",
            "Long ago, in a kingdom by the sea, a young girl made a remarkable discovery.",
            "An older male narrator reads a bedtime story slowly, with warmth.",
        ),
        (
            "whisper_female",
            "Don't say a word. Just listen carefully.",
            "A young woman whispers, conspiratorial, low energy, very quiet.",
        ),
        (
            "projecting_announcer",
            "And he scores in the final second of the match!",
            "A high-pitched announcer projects an exciting headline at a fast pace.",
        ),
    )
)
# Sampling keyword arguments that synth_one unpacks into
# model.generate_voice_design for every clip.
SAMPLER = {
    "temperature": 0.85,
    "top_k": 50,
    "top_p": 0.95,
    "repetition_penalty": 1.05,
    "max_new_tokens": 600,
    "do_sample": True,
}
def parse_args(argv: list[str] | None = None) -> argparse.Namespace:
    """Parse the demo's command-line options.

    Args:
        argv: Argument strings to parse. ``None`` makes argparse read
            ``sys.argv[1:]``.

    Returns:
        Namespace with ``source``, ``out``, ``precision`` and ``device``.
    """
    # __doc__ is None when the module docstring is missing or stripped
    # (e.g. ``python -OO``); fall back to no description instead of failing
    # on ``None.split``.
    summary = (__doc__ or "").split("\n", 1)[0] or None
    parser = argparse.ArgumentParser(description=summary)
    parser.add_argument(
        "--source",
        default=".",
        help="HF repo id or local checkpoint dir",
    )
    parser.add_argument(
        "--out",
        default="./demo_out",
        help="output dir for wav files",
    )
    parser.add_argument(
        "--precision",
        default="bfloat16",
        choices=("bfloat16", "float16", "float32"),
    )
    # Prefer the first CUDA device when torch reports one, otherwise the CPU.
    parser.add_argument(
        "--device",
        default="cuda:0" if torch.cuda.is_available() else "cpu",
    )
    return parser.parse_args(argv)
def load(source: str, device: str, precision: str) -> Qwen3TTSModel:
    """Load the checkpoint at *source* onto *device* with the given precision.

    Args:
        source: Passed unchanged to ``Qwen3TTSModel.from_pretrained``.
        device: Forwarded as ``device_map``.
        precision: One of ``"bfloat16"``, ``"float16"`` or ``"float32"``.

    Returns:
        The model object returned by ``Qwen3TTSModel.from_pretrained``.

    Raises:
        KeyError: If *precision* is not one of the three supported names.
    """
    torch_dtypes = {
        "bfloat16": torch.bfloat16,
        "float16": torch.float16,
        "float32": torch.float32,
    }
    # Resolve the dtype first so an unknown precision fails before any output.
    chosen_dtype = torch_dtypes[precision]
    print(f"[demo] loading {source!r} -> {device} ({precision})", flush=True)
    model = Qwen3TTSModel.from_pretrained(
        source,
        device_map=device,
        dtype=chosen_dtype,
    )
    return model
def synth_one(model: Qwen3TTSModel, sample: Sample, out_dir: Path) -> Path:
    """Generate one preset clip and save it as ``<slug>.wav`` in *out_dir*.

    Args:
        model: Loaded model exposing ``generate_voice_design``.
        sample: Preset supplying the text (``say``), the voice prompt
            (``voice``) and the output file stem (``slug``).
        out_dir: Directory the wav file is written into.

    Returns:
        Path of the written wav file.
    """
    audio_batch, sample_rate = model.generate_voice_design(
        text=sample.say,
        instruct=sample.voice,
        language="english",
        **SAMPLER,
    )
    # Only the first returned waveform is saved and measured.
    clip = audio_batch[0]
    wav_path = out_dir / f"{sample.slug}.wav"
    sf.write(wav_path, clip, sample_rate)
    seconds = len(clip) / sample_rate
    print(f" -> {wav_path.name} ({seconds:.2f}s @ {sample_rate} Hz)")
    return wav_path
def run(args: argparse.Namespace) -> int:
    """Load the model once and render every preset in ``SAMPLES``.

    Args:
        args: Parsed options; reads ``out``, ``source``, ``device`` and
            ``precision``.

    Returns:
        Always ``0``, used by the caller as the process exit status.
    """
    destination = Path(args.out)
    # Create the output directory (and parents) up front; reuse it if present.
    destination.mkdir(parents=True, exist_ok=True)

    tts = load(args.source, args.device, args.precision)
    for preset in SAMPLES:
        synth_one(tts, preset, destination)

    print(f"[demo] {len(SAMPLES)} clips written to {destination}/", flush=True)
    return 0
if __name__ == "__main__":
    # Script entry point: parse the CLI, run the demo, and exit with its
    # return value as the process status.
    exit_code = run(parse_args())
    sys.exit(exit_code)