Spaces:
Running
on
Zero
Running
on
Zero
Delete cli.py
Browse files
cli.py
DELETED
|
@@ -1,122 +0,0 @@
|
|
| 1 |
-
from __future__ import annotations
|
| 2 |
-
|
| 3 |
-
import argparse
|
| 4 |
-
|
| 5 |
-
import torch
|
| 6 |
-
|
| 7 |
-
from .engine import Dia2
|
| 8 |
-
from .generation import (
|
| 9 |
-
build_generation_config,
|
| 10 |
-
load_script_text,
|
| 11 |
-
validate_generation_params,
|
| 12 |
-
)
|
| 13 |
-
|
| 14 |
-
|
| 15 |
-
def main() -> None:
|
| 16 |
-
parser = argparse.ArgumentParser(description="Generate audio with Dia2")
|
| 17 |
-
parser.add_argument("--config", help="Path to config.json (overrides repo lookup)")
|
| 18 |
-
parser.add_argument(
|
| 19 |
-
"--weights", help="Path to model.safetensors (overrides repo lookup)"
|
| 20 |
-
)
|
| 21 |
-
parser.add_argument(
|
| 22 |
-
"--hf",
|
| 23 |
-
required=False,
|
| 24 |
-
help="Hugging Face repo id to download config/weights from (e.g. nari-labs/Dia2-2B)",
|
| 25 |
-
)
|
| 26 |
-
parser.add_argument(
|
| 27 |
-
"--input", default="input.txt", help="Script text file (default: input.txt)"
|
| 28 |
-
)
|
| 29 |
-
parser.add_argument("output", help="Output WAV path")
|
| 30 |
-
parser.add_argument(
|
| 31 |
-
"--device",
|
| 32 |
-
default=None,
|
| 33 |
-
help="Computation device (defaults to cuda if available, else cpu)",
|
| 34 |
-
)
|
| 35 |
-
parser.add_argument(
|
| 36 |
-
"--dtype",
|
| 37 |
-
choices=["auto", "float32", "bfloat16"],
|
| 38 |
-
default="bfloat16",
|
| 39 |
-
help="Computation dtype (default: bfloat16)",
|
| 40 |
-
)
|
| 41 |
-
parser.add_argument("--topk", type=int, default=50)
|
| 42 |
-
parser.add_argument("--temperature", type=float, default=0.8)
|
| 43 |
-
parser.add_argument("--cfg", type=float, default=1.0)
|
| 44 |
-
parser.add_argument("--tokenizer", help="Tokenizer repo or local path override")
|
| 45 |
-
parser.add_argument(
|
| 46 |
-
"--mimi", help="Mimi repo id override (defaults to config/assets)"
|
| 47 |
-
)
|
| 48 |
-
parser.add_argument("--prefix-speaker-1", help="Prefix audio file for speaker 1")
|
| 49 |
-
parser.add_argument("--prefix-speaker-2", help="Prefix audio file for speaker 2")
|
| 50 |
-
parser.add_argument(
|
| 51 |
-
"--include-prefix",
|
| 52 |
-
action="store_true",
|
| 53 |
-
help="Keep prefix audio in the final waveform (default: trimmed)",
|
| 54 |
-
)
|
| 55 |
-
parser.add_argument(
|
| 56 |
-
"--verbose", action="store_true", help="Print generation progress logs"
|
| 57 |
-
)
|
| 58 |
-
parser.add_argument(
|
| 59 |
-
"--cuda-graph",
|
| 60 |
-
action="store_true",
|
| 61 |
-
help="Run generation with CUDA graph capture",
|
| 62 |
-
)
|
| 63 |
-
args = parser.parse_args()
|
| 64 |
-
|
| 65 |
-
device = args.device
|
| 66 |
-
if device is None or device == "auto":
|
| 67 |
-
device = "cuda" if torch.cuda.is_available() else "cpu"
|
| 68 |
-
dtype = args.dtype or "bfloat16"
|
| 69 |
-
|
| 70 |
-
repo = args.hf
|
| 71 |
-
if repo:
|
| 72 |
-
dia = Dia2(
|
| 73 |
-
repo=repo,
|
| 74 |
-
device=device,
|
| 75 |
-
dtype=dtype,
|
| 76 |
-
tokenizer_id=args.tokenizer,
|
| 77 |
-
mimi_id=args.mimi,
|
| 78 |
-
)
|
| 79 |
-
elif args.config and args.weights:
|
| 80 |
-
dia = Dia2.from_local(
|
| 81 |
-
config_path=args.config,
|
| 82 |
-
weights_path=args.weights,
|
| 83 |
-
device=device,
|
| 84 |
-
dtype=dtype,
|
| 85 |
-
tokenizer_id=args.tokenizer,
|
| 86 |
-
mimi_id=args.mimi,
|
| 87 |
-
)
|
| 88 |
-
else:
|
| 89 |
-
raise ValueError("Provide --hf/--variant or both --config and --weights")
|
| 90 |
-
|
| 91 |
-
script = load_script_text(args.input)
|
| 92 |
-
temperature, top_k, cfg_scale = validate_generation_params(
|
| 93 |
-
temperature=args.temperature,
|
| 94 |
-
top_k=args.topk,
|
| 95 |
-
cfg_scale=args.cfg,
|
| 96 |
-
)
|
| 97 |
-
config = build_generation_config(
|
| 98 |
-
temperature=temperature,
|
| 99 |
-
top_k=top_k,
|
| 100 |
-
cfg_scale=cfg_scale,
|
| 101 |
-
)
|
| 102 |
-
overrides = {}
|
| 103 |
-
if args.cuda_graph:
|
| 104 |
-
overrides["use_cuda_graph"] = True
|
| 105 |
-
if args.prefix_speaker_1:
|
| 106 |
-
overrides["prefix_speaker_1"] = args.prefix_speaker_1
|
| 107 |
-
if args.prefix_speaker_2:
|
| 108 |
-
overrides["prefix_speaker_2"] = args.prefix_speaker_2
|
| 109 |
-
if args.include_prefix:
|
| 110 |
-
overrides["include_prefix"] = True
|
| 111 |
-
|
| 112 |
-
dia.generate(
|
| 113 |
-
script,
|
| 114 |
-
config=config,
|
| 115 |
-
output_wav=args.output,
|
| 116 |
-
verbose=args.verbose,
|
| 117 |
-
**overrides,
|
| 118 |
-
)
|
| 119 |
-
|
| 120 |
-
|
| 121 |
-
if __name__ == "__main__":
|
| 122 |
-
main()
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|
|