File size: 4,858 Bytes
e3f3734
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
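"""

Gradio and CLI entrypoint for the text-to-audio pipeline.

Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"]
                    [--output PATH] [--no-profile]

Without --cli, a Gradio UI is served on 0.0.0.0:7860. The --text, --output and
--no-profile flags only take effect in --cli mode.

"""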

from __future__ import annotations

import argparse
import sys

import numpy as np


def run_gradio(
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
) -> None:
    """Load the pipeline once and serve a Gradio UI for interactive synthesis.

    Args:
        preset: Model preset name passed to ``build_pipeline``.
        use_4bit: Passed through to ``build_pipeline``.
        use_8bit: Passed through to ``build_pipeline``.

    Blocks inside ``app.launch`` (all interfaces, port 7860) until the server
    is stopped.
    """
    import tempfile

    import gradio as gr
    import soundfile as sf
    from src.text_to_audio import build_pipeline

    pipe = build_pipeline(
        preset=preset,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
    )

    def to_writable_array(audio) -> np.ndarray:
        """Convert pipeline audio output into an array ``soundfile.write`` accepts."""
        if hasattr(audio, "detach"):
            # Torch-like tensor: drop autograd, move to host memory and widen
            # half/bfloat16 (numpy has no bfloat16) before converting.
            audio = audio.detach().cpu()
            if audio.is_floating_point():
                audio = audio.float()
        arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
        if arr.dtype.kind == "f" and arr.dtype != np.float64:
            # soundfile only accepts float64/float32/int32/int16 data.
            arr = arr.astype(np.float32, copy=False)
        # soundfile expects (frames, channels); 2-D output is assumed to be
        # (channels, frames) -- TODO confirm against the pipeline's contract.
        return arr.T if arr.ndim == 2 else arr

    def generate_audio(text: str, progress=gr.Progress()) -> tuple[str | None, str]:
        """Synthesize *text*; return (path of the written audio or None, status text)."""
        if not text or not text.strip():
            return None, "Enter some text to synthesize."
        progress(0.2, desc="Generating...")
        try:
            out, profile = pipe.generate_with_profile(text.strip())
            single = out if isinstance(out, dict) else out[0]
            arr = to_writable_array(single["audio"])
            sr = single["sampling_rate"]
            # One file per request: a fixed path lets overlapping requests
            # overwrite each other's output, and /tmp is not portable.
            with tempfile.NamedTemporaryFile(
                prefix="tta_", suffix=".wav", delete=False
            ) as tmp:
                path = tmp.name
            sf.write(path, arr, sr)
            summary = (
                f"Done — {profile.get('time_s', 0):.2f}s, "
                f"RTF={profile.get('rtf', 0):.2f}"
            )
            progress(1.0, desc=summary)
            return path, summary
        except Exception as e:
            # Show the failure in the UI rather than a generic server error.
            raise gr.Error(str(e)) from e

    with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
        gr.Markdown(f"Active preset: `{preset}`")
        with gr.Row():
            text_in = gr.Textbox(
                label="Text",
                placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
                lines=3,
            )
        with gr.Row():
            gen_btn = gr.Button("Generate", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Output", type="filepath")
        status = gr.Markdown("")
        # The handler returns the status text itself, so timing/RTF stays
        # visible and a failed run is not reported as "Ready.".
        gen_btn.click(
            fn=generate_audio,
            inputs=[text_in],
            outputs=[audio_out, status],
        )
        gr.Markdown(
            "### Prompt ideas\n"
            "- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n"
            "- **Expressive:** Use punctuation and short sentences for best quality.\n"
            "- **Music (MusicGen):** Relaunch with `--model musicgen-small` and try: "
            "\"Upbeat electronic dance music with a strong bass line.\""
        )

    app.launch(server_name="0.0.0.0", server_port=7860)


def run_cli(
    text: str,
    output_path: str,
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    profile: bool = True,
) -> int:
    """Synthesize *text* once and write the audio to *output_path*.

    Args:
        text: Text passed unchanged to the pipeline.
        output_path: Destination file handed to ``soundfile.write``.
        preset: Model preset name passed to ``build_pipeline``.
        use_4bit: Passed through to ``build_pipeline``.
        use_8bit: Passed through to ``build_pipeline``.
        profile: If true, call ``generate_with_profile`` and print time, RTF
            and peak VRAM; otherwise call ``generate``.

    Returns:
        0 on success (used as the process exit status); errors propagate.
    """
    from src.text_to_audio import build_pipeline
    import soundfile as sf

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
    if profile:
        out, prof = pipe.generate_with_profile(text)
        print(f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} | VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB")
    else:
        out = pipe.generate(text)
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    if hasattr(audio, "detach"):
        # Torch-like tensor: a plain .numpy() fails on CUDA, autograd or
        # bfloat16 tensors, so detach, move to host and widen floats first.
        audio = audio.detach().cpu()
        if audio.is_floating_point():
            audio = audio.float()
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    if arr.dtype.kind == "f" and arr.dtype != np.float64:
        # soundfile only accepts float64/float32/int32/int16 data.
        arr = arr.astype(np.float32, copy=False)

    # soundfile expects (frames, channels); 2-D output is assumed to be
    # (channels, frames) -- TODO confirm against the pipeline's contract.
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    print(f"Wrote {output_path} ({sr} Hz)")
    return 0


def main() -> int:
    """Parse command-line flags and dispatch to the Gradio UI or the CLI.

    Returns:
        Process exit status: 0 after the Gradio server returns, otherwise
        whatever ``run_cli`` returns.
    """
    ap = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo")
    ap.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio")
    ap.add_argument(
        "--model",
        default="csm-1b",
        choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"],
        help="Model preset",
    )
    ap.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)")
    ap.add_argument("--text", default="", help="Input text (CLI mode)")
    ap.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)")
    ap.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print")
    opts = ap.parse_args()

    # Only 4-bit loading is exposed on the command line; 8-bit is always off.
    model_kwargs = {"preset": opts.model, "use_4bit": opts.quantize, "use_8bit": False}

    if not opts.cli:
        run_gradio(**model_kwargs)
        return 0

    prompt = opts.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline."
    return run_cli(
        text=prompt,
        output_path=opts.output,
        profile=not opts.no_profile,
        **model_kwargs,
    )


if __name__ == "__main__":
    # Propagate main()'s return value as the process exit status.
    raise SystemExit(main())