# Spaces status: Runtime error
# File size: 4,858 Bytes (revision e3f3734)
"""
Gradio and CLI entrypoint for the text-to-audio pipeline.
Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"]
"""
from __future__ import annotations
import argparse
import sys
import numpy as np
def run_gradio(
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
) -> None:
    """Build the pipeline once and serve it through a Gradio web UI.

    gradio, soundfile and the project pipeline are imported inside this
    function, so they are only needed when the UI is actually started.

    Args:
        preset: Preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline`` as ``use_4bit``.
        use_8bit: Forwarded to ``build_pipeline`` as ``use_8bit``.

    The app is launched on ``0.0.0.0:7860``.
    """
    import tempfile

    import gradio as gr
    import soundfile as sf
    from src.text_to_audio import build_pipeline, list_presets

    # NOTE(review): the result is not used anywhere in this function. The call
    # is kept because list_presets() is defined outside this file.
    presets = list_presets()
    pipe = build_pipeline(
        preset=preset,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
    )

    def generate_audio(text: str, progress=gr.Progress()) -> str | None:
        """Synthesize ``text`` and return the path of the WAV file written.

        Returns ``None`` for blank input. Any failure is re-raised as
        ``gr.Error`` so the message is shown in the UI.
        """
        if not text or not text.strip():
            return None
        progress(0.2, desc="Generating...")
        try:
            out, profile = pipe.generate_with_profile(text.strip())
            # The pipeline may hand back one result dict or a sequence of them.
            single = out if isinstance(out, dict) else out[0]
            audio = single["audio"]
            sr = single["sampling_rate"]
            # Tensor-like results must be detached and moved to host memory
            # first; calling .numpy() directly fails for accelerator-resident
            # or grad-tracking tensors.
            if hasattr(audio, "detach"):
                audio = audio.detach().cpu()
            arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
            # A unique file per request: a single fixed path is overwritten by
            # every request/process and "/tmp" does not exist on all platforms.
            with tempfile.NamedTemporaryFile(
                prefix="tta_", suffix=".wav", delete=False
            ) as tmp:
                path = tmp.name
            # assumes 2-D audio is (channels, samples), hence the transpose to
            # soundfile's (frames, channels) layout — TODO confirm per preset
            sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
            progress(1.0, desc=f"Done — {profile.get('time_s', 0):.2f}s, RTF={profile.get('rtf', 0):.2f}")
            return path
        except Exception as e:
            raise gr.Error(str(e)) from e

    with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
        with gr.Row():
            text_in = gr.Textbox(
                label="Text",
                placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
                lines=3,
            )
        with gr.Row():
            gen_btn = gr.Button("Generate", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Output", type="filepath")
        status = gr.Markdown("")
        # Generate first, then update the status line once generation returns.
        gen_btn.click(
            fn=generate_audio,
            inputs=[text_in],
            outputs=[audio_out],
        ).then(
            fn=lambda: "Ready.",
            outputs=[status],
        )
        gr.Markdown("### Prompt ideas\n- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n- **Expressive:** Use punctuation and short sentences for best quality.\n- **Music (MusicGen):** Switch preset to musicgen-small and try: \"Upbeat electronic dance music with a strong bass line.\"")
    app.launch(server_name="0.0.0.0", server_port=7860)
def run_cli(
    text: str,
    output_path: str,
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    profile: bool = True,
) -> int:
    """Synthesize ``text`` once and write the result to ``output_path``.

    Args:
        text: Text to synthesize. Blank text is rejected.
        output_path: Destination path handed to ``soundfile.write``.
        preset: Preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline`` as ``use_4bit``.
        use_8bit: Forwarded to ``build_pipeline`` as ``use_8bit``.
        profile: When true, use ``generate_with_profile`` and print the
            time / RTF / peak-VRAM figures it reports.

    Returns:
        Process exit code: ``0`` on success, ``1`` when ``text`` is blank.
    """
    # Reject blank input up front (the Gradio path does the same) so no model
    # is loaded just to fail or to synthesize nothing.
    if not text or not text.strip():
        print("error: no input text to synthesize", file=sys.stderr)
        return 1

    from src.text_to_audio import build_pipeline
    import soundfile as sf

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
    if profile:
        out, prof = pipe.generate_with_profile(text)
        print(f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} | VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB")
    else:
        out = pipe.generate(text)

    # The pipeline may hand back one result dict or a sequence of them.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    # Tensor-like results must be detached and moved to host memory first;
    # calling .numpy() directly fails for accelerator-resident or
    # grad-tracking tensors.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu()
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    # assumes 2-D audio is (channels, samples), hence the transpose to
    # soundfile's (frames, channels) layout — TODO confirm per preset
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    print(f"Wrote {output_path} ({sr} Hz)")
    return 0
def main(argv: list[str] | None = None) -> int:
    """Parse command-line options and start the CLI or the Gradio demo.

    Args:
        argv: Argument list to parse. ``None`` (the default) makes argparse
            read ``sys.argv[1:]``, which is the previous behavior; passing a
            list allows the entry point to be driven programmatically.

    Returns:
        Process exit code: the value returned by ``run_cli`` in ``--cli``
        mode, otherwise ``0`` once ``run_gradio`` returns.
    """
    parser = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo")
    parser.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio")
    parser.add_argument("--model", default="csm-1b", choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"], help="Model preset")
    parser.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)")
    parser.add_argument("--text", default="", help="Input text (CLI mode)")
    parser.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)")
    parser.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print")
    args = parser.parse_args(argv)

    if args.cli:
        # Fall back to a sample sentence when --text is omitted or empty.
        text = args.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline."
        return run_cli(
            text=text,
            output_path=args.output,
            preset=args.model,
            use_4bit=args.quantize,
            use_8bit=False,  # --quantize only selects 4-bit loading
            profile=not args.no_profile,
        )

    run_gradio(preset=args.model, use_4bit=args.quantize, use_8bit=False)
    return 0
# Script entry point: run main() and use its return value as the exit status.
if __name__ == "__main__":
    raise SystemExit(main())
|