File size: 3,984 Bytes
aed1d05
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
#!/usr/bin/env python3
"""
Generate audio from MODEL-W session specs using ACE-Step 1.5.

End-to-end pipeline:
  session JSON → caption + metadata → ACE-Step DiT → rendered audio

Usage:
  python scripts/generate_audio.py --sessions synthetic/sessions/corpus_200 --out output/audio
  python scripts/generate_audio.py --session synthetic/sessions/example_trap_fullsong.json
  python scripts/generate_audio.py --caption "dark trap beat, D minor, 140 BPM" --duration 60
"""

from __future__ import annotations

import argparse
import json
import sys
from pathlib import Path

_ROOT = Path(__file__).resolve().parents[1]
if str(_ROOT) not in sys.path:
    sys.path.insert(0, str(_ROOT))

from modelw.acestep_bridge import (
    ACEStepBridge,
    ACEStepConfig,
    session_to_caption,
    preview_captions,
)


def load_env_config(env_file: "Path | None" = None) -> dict:
    """Read ``KEY=VALUE`` settings from ``.env.acestep`` if it exists.

    Parsing rules (a small subset of common dotenv syntax):

    * blank lines and lines whose first non-blank character is ``#`` are
      skipped;
    * an optional leading ``export `` is ignored;
    * each line is split on the first ``=`` only, so values may contain ``=``;
    * keys and values are whitespace-stripped, and one pair of matching
      surrounding quotes (``"..."`` or ``'...'``) is removed from the value;
    * lines without ``=`` or with an empty key are ignored.

    Args:
        env_file: File to read. Defaults to ``.env.acestep`` in the
            repository root (``_ROOT``).

    Returns:
        Mapping of keys to string values; empty if the file does not exist
        (or is not a regular file).
    """
    if env_file is None:
        env_file = _ROOT / ".env.acestep"
    env_path = Path(env_file)
    cfg: dict[str, str] = {}
    # is_file() rather than exists(): a directory at this path would make
    # read_text() raise.
    if not env_path.is_file():
        return cfg
    for raw in env_path.read_text(encoding="utf-8").splitlines():
        # Strip first so indented comments are recognised as comments.
        line = raw.strip()
        if not line or line.startswith("#"):
            continue
        if line.startswith("export "):
            line = line[len("export "):].lstrip()
        key, sep, value = line.partition("=")
        key = key.strip()
        if not sep or not key:
            continue
        value = value.strip()
        # Drop one pair of matching surrounding quotes, e.g. ROOT="/opt/ace".
        if len(value) >= 2 and value[0] == value[-1] and value[0] in "\"'":
            value = value[1:-1]
        cfg[key] = value
    return cfg


def _report_result(result) -> bool:
    """Print generated file paths on success, or the error on failure.

    Returns ``result.success`` so the caller can choose an exit status.
    """
    if result.success:
        for audio in result.audios:
            print(f"Generated: {audio['path']}")
        return True
    print(f"Error: {result.error}", file=sys.stderr)
    return False


def main(argv: "list[str] | None" = None) -> None:
    """Parse command-line arguments and run one generation mode.

    Modes (mutually exclusive, exactly one required):

    * ``--preview PATH`` - only call ``preview_captions``; the ACE-Step
      bridge is never initialized.
    * ``--caption TEXT`` - generate directly from a text caption.
    * ``--session FILE`` - generate from a single session JSON file.
    * ``--sessions DIR`` - batch-generate from a directory of session files.

    The DiT and LM model names come from ``--dit``/``--lm`` first, then
    ``.env.acestep``, then built-in defaults; ``ACESTEP_ROOT`` comes from
    ``.env.acestep`` or defaults to ``models/ace-step`` under the repo root.

    Args:
        argv: Argument list to parse; ``None`` (default) uses ``sys.argv[1:]``.

    Raises:
        SystemExit: with status 2 (via argparse) on invalid arguments, or with
            status 1 when a single caption/session generation reports failure.
    """
    ap = argparse.ArgumentParser(description="Generate audio from MODEL-W sessions via ACE-Step")
    grp = ap.add_mutually_exclusive_group(required=True)
    grp.add_argument("--sessions", type=str, help="Directory of session JSON files")
    grp.add_argument("--session", type=str, help="Single session JSON file")
    grp.add_argument("--caption", type=str, help="Direct text caption (no session file)")
    grp.add_argument("--preview", type=str, help="Preview captions without generating (no GPU)")

    ap.add_argument("--out", type=str, default="output/audio")
    ap.add_argument("--max-files", type=int, default=None)
    ap.add_argument("--batch-size", type=int, default=1)
    ap.add_argument("--duration", type=float, default=None)
    ap.add_argument("--bpm", type=int, default=120)
    ap.add_argument("--seed", type=int, default=-1)
    ap.add_argument("--dit", type=str, default=None, help="DiT config override")
    ap.add_argument("--lm", type=str, default=None, help="LM model override")
    ap.add_argument("--device", type=str, default=None)
    args = ap.parse_args(argv)

    # Validate cheaply before any model is loaded. Mode checks use
    # `is not None`: an empty-string value still satisfies argparse's required
    # group, and previously made every branch below silently fall through.
    if args.batch_size < 1:
        ap.error("--batch-size must be at least 1")
    if args.duration is not None and args.duration <= 0:
        ap.error("--duration must be positive")
    if args.caption is not None and not args.caption.strip():
        ap.error("--caption must not be empty")

    if args.preview is not None:
        preview_captions(
            args.preview,
            max_files=args.max_files if args.max_files is not None else 20,
        )
        return

    env = load_env_config()
    config = ACEStepConfig(
        acestep_root=env.get("ACESTEP_ROOT", str(_ROOT / "models/ace-step")),
        dit_config=args.dit or env.get("ACESTEP_DIT_CONFIG", "acestep-v15-turbo"),
        lm_model=args.lm or env.get("ACESTEP_LM_MODEL", "acestep-5Hz-lm-1.7B"),
        output_dir=args.out,
    )
    if args.device:
        config.device = args.device

    bridge = ACEStepBridge(config)
    bridge.initialize()

    if args.caption is not None:
        result = bridge.generate_from_caption(
            caption=args.caption,
            bpm=args.bpm,
            duration=args.duration if args.duration is not None else 30.0,
            batch_size=args.batch_size,
            seed=args.seed,
            save_dir=args.out,
        )
        if not _report_result(result):
            sys.exit(1)

    elif args.session is not None:
        result = bridge.generate_from_session_file(
            args.session,
            duration=args.duration,
            batch_size=args.batch_size,
            seed=args.seed,
            save_dir=args.out,
        )
        if not _report_result(result):
            sys.exit(1)

    elif args.sessions is not None:
        # The batch call below does not receive --duration or --seed; tell the
        # user instead of silently ignoring explicitly supplied values.
        if args.duration is not None or args.seed != -1:
            print(
                "Warning: --duration and --seed are ignored in --sessions mode",
                file=sys.stderr,
            )
        bridge.batch_generate_corpus(
            sessions_dir=args.sessions,
            save_dir=args.out,
            max_files=args.max_files,
            batch_size=args.batch_size,
        )

if __name__ == "__main__":
    # Run the CLI when executed as a script; main()'s return value (None on
    # normal completion) becomes the process exit status, i.e. 0.
    raise SystemExit(main())