Keith
Initial commit for HF Space
e3f3734
"""
Gradio and CLI entrypoint for the text-to-audio pipeline.
Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"]
"""
from __future__ import annotations
import argparse
import sys
import numpy as np
def run_gradio(
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    server_name: str = "0.0.0.0",
    server_port: int = 7860,
) -> None:
    """Build the pipeline once and serve a Gradio UI around it.

    The pipeline for *preset* is constructed before the UI is created, so every
    request reuses the same pipeline object; there is no in-UI preset switch.

    Args:
        preset: Model preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline`` (4-bit loading).
        use_8bit: Forwarded to ``build_pipeline`` (8-bit loading).
        server_name: Interface passed to ``app.launch`` (default: all interfaces).
        server_port: Port passed to ``app.launch``.
    """
    import tempfile

    import gradio as gr
    import soundfile as sf
    from src.text_to_audio import build_pipeline

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)

    def to_soundfile_array(audio) -> np.ndarray:
        """Convert pipeline audio into a NumPy array laid out for ``sf.write``."""
        if hasattr(audio, "detach"):
            # Presumably a torch.Tensor: leave autograd and move to host memory,
            # since ``.numpy()`` refuses GPU tensors.
            audio = audio.detach().cpu()
            if audio.is_floating_point():
                # NumPy has no bfloat16 and soundfile cannot write float16.
                audio = audio.float()
        arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
        if arr.dtype == np.float16:
            arr = arr.astype(np.float32)
        # Drop leading singleton (batch) axes so at most 2 dimensions remain.
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr[0]
        # soundfile expects (frames, channels); treat the shorter axis as channels
        # and only transpose when the data is channels-first.
        if arr.ndim == 2 and arr.shape[0] < arr.shape[1]:
            arr = arr.T
        return arr

    def generate_audio(text: str, progress=gr.Progress()) -> tuple[str | None, str]:
        """Synthesize *text*; return (WAV file path or None, status markdown)."""
        text = (text or "").strip()
        if not text:
            return None, "Enter some text to synthesize."
        progress(0.2, desc="Generating...")
        try:
            out, profile = pipe.generate_with_profile(text)
            single = out if isinstance(out, dict) else out[0]
            arr = to_soundfile_array(single["audio"])
            sr = single["sampling_rate"]
            # One file per request so overlapping requests never overwrite each
            # other's output. It is not deleted here because Gradio reads the
            # returned path after this function returns.
            with tempfile.NamedTemporaryFile(
                prefix="tta_", suffix=".wav", delete=False
            ) as tmp:
                path = tmp.name
            sf.write(path, arr, sr)
        except Exception as e:
            # UI boundary: surface any failure to the user as a Gradio error.
            raise gr.Error(str(e)) from e
        summary = (
            f"Done — {profile.get('time_s', 0):.2f}s, RTF={profile.get('rtf', 0):.2f}"
        )
        progress(1.0, desc=summary)
        return path, summary

    with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
        with gr.Row():
            text_in = gr.Textbox(
                label="Text",
                placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
                lines=3,
            )
        with gr.Row():
            gen_btn = gr.Button("Generate", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Output", type="filepath")
        status = gr.Markdown("")
        # The status line is filled by generate_audio itself, so it only reports
        # success when generation actually succeeded; errors raise gr.Error.
        gen_btn.click(
            fn=generate_audio,
            inputs=[text_in],
            outputs=[audio_out, status],
        )
        gr.Markdown(
            "### Prompt ideas\n"
            "- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n"
            "- **Expressive:** Use punctuation and short sentences for best quality.\n"
            "- **Music (MusicGen):** Restart with `--model musicgen-small` and try: "
            "\"Upbeat electronic dance music with a strong bass line.\""
        )
    app.launch(server_name=server_name, server_port=server_port)
def run_cli(
    text: str,
    output_path: str,
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    profile: bool = True,
) -> int:
    """Synthesize *text* once and write the audio to *output_path*.

    Args:
        text: Text to synthesize; surrounding whitespace is stripped.
        output_path: Destination audio file; missing parent directories are created.
        preset: Model preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline`` (4-bit loading).
        use_8bit: Forwarded to ``build_pipeline`` (8-bit loading).
        profile: If true, call ``generate_with_profile`` and print the time, RTF
            and peak-VRAM figures it reports; otherwise call ``generate``.

    Returns:
        Process exit status: 0 on success, 2 if *text* is empty or whitespace.
    """
    text = (text or "").strip()
    if not text:
        # Reject before importing dependencies and building the pipeline.
        print("error: no text to synthesize", file=sys.stderr)
        return 2

    from pathlib import Path

    import soundfile as sf
    from src.text_to_audio import build_pipeline

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
    if profile:
        out, prof = pipe.generate_with_profile(text)
        print(
            f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} "
            f"| VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB"
        )
    else:
        out = pipe.generate(text)

    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    if hasattr(audio, "detach"):
        # Presumably a torch.Tensor: leave autograd and move to host memory,
        # since ``.numpy()`` refuses GPU tensors.
        audio = audio.detach().cpu()
        if audio.is_floating_point():
            # NumPy has no bfloat16 and soundfile cannot write float16.
            audio = audio.float()
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    if arr.dtype == np.float16:
        arr = arr.astype(np.float32)
    # Drop leading singleton (batch) axes so at most 2 dimensions remain.
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]
    # soundfile expects (frames, channels); treat the shorter axis as channels
    # and only transpose when the data is channels-first.
    if arr.ndim == 2 and arr.shape[0] < arr.shape[1]:
        arr = arr.T

    Path(output_path).parent.mkdir(parents=True, exist_ok=True)
    sf.write(output_path, arr, sr)
    print(f"Wrote {output_path} ({sr} Hz)")
    return 0
def main() -> int:
parser = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo")
parser.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio")
parser.add_argument("--model", default="csm-1b", choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"], help="Model preset")
parser.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)")
parser.add_argument("--text", default="", help="Input text (CLI mode)")
parser.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)")
parser.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print")
args = parser.parse_args()
if args.cli:
text = args.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline."
return run_cli(
text=text,
output_path=args.output,
preset=args.model,
use_4bit=args.quantize,
use_8bit=False,
profile=not args.no_profile,
)
run_gradio(preset=args.model, use_4bit=args.quantize, use_8bit=False)
return 0
if __name__ == "__main__":
    # Script entry point: the status returned by main() becomes the exit code.
    raise SystemExit(main())