Spaces:
Runtime error
Runtime error
| """ | |
| Gradio and CLI entrypoint for the text-to-audio pipeline. | |
| Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"] [--output PATH] [--no-profile] | |
| """ | |
| from __future__ import annotations | |
| import argparse | |
| import sys | |
| import numpy as np | |
def run_gradio(
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
) -> None:
    """Build the text-to-audio pipeline and serve a Gradio UI around it.

    The pipeline is built once, before the UI is assembled, and is reused by
    every "Generate" click. The app is launched on ``0.0.0.0:7860``.

    Args:
        preset: Preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline``.
        use_8bit: Forwarded to ``build_pipeline``.
    """
    import tempfile

    import gradio as gr
    import soundfile as sf
    from src.text_to_audio import build_pipeline

    pipe = build_pipeline(
        preset=preset,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
    )

    def _to_numpy(audio):
        """Convert pipeline audio (tensor or array-like) to a NumPy array."""
        # torch's Tensor.numpy() raises for tensors that live on an
        # accelerator or are part of the autograd graph, so move to CPU first.
        if hasattr(audio, "detach"):
            audio = audio.detach().cpu()
        arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
        # soundfile only accepts 1-D or 2-D data; strip leading singleton
        # axes (presumably a batch axis) so such outputs can still be written.
        while arr.ndim > 2 and arr.shape[0] == 1:
            arr = arr[0]
        return arr

    def generate_audio(text: str, progress=gr.Progress()) -> str | None:
        """Synthesize ``text`` and return the path of the WAV file written.

        Returns ``None`` for empty or whitespace-only input. Any failure is
        re-raised as ``gr.Error`` so the message is shown in the UI.
        """
        if not text or not text.strip():
            return None
        progress(0.2, desc="Generating...")
        try:
            out, profile = pipe.generate_with_profile(text.strip())
            single = out if isinstance(out, dict) else out[0]
            arr = _to_numpy(single["audio"])
            sr = single["sampling_rate"]
            # One unique file per request: a fixed path would let overlapping
            # requests overwrite each other's audio, and "/tmp" is not
            # portable. delete=False keeps the file for Gradio to serve.
            with tempfile.NamedTemporaryFile(
                prefix="tta_", suffix=".wav", delete=False
            ) as tmp:
                path = tmp.name
            # 2-D input is transposed before writing, as in the original code.
            sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
            progress(1.0, desc=f"Done — {profile.get('time_s', 0):.2f}s, RTF={profile.get('rtf', 0):.2f}")
            return path
        except Exception as e:
            raise gr.Error(str(e)) from e

    with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
        with gr.Row():
            text_in = gr.Textbox(
                label="Text",
                placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
                lines=3,
            )
        with gr.Row():
            gen_btn = gr.Button("Generate", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Output", type="filepath")
        status = gr.Markdown("")
        # Generate first, then update the status line once generation ends.
        gen_btn.click(
            fn=generate_audio,
            inputs=[text_in],
            outputs=[audio_out],
        ).then(
            fn=lambda: "Ready.",
            outputs=[status],
        )
        gr.Markdown("### Prompt ideas\n- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n- **Expressive:** Use punctuation and short sentences for best quality.\n- **Music (MusicGen):** Switch preset to musicgen-small and try: \"Upbeat electronic dance music with a strong bass line.\"")
    app.launch(server_name="0.0.0.0", server_port=7860)
def run_cli(
    text: str,
    output_path: str,
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    profile: bool = True,
) -> int:
    """Synthesize ``text`` once and write the audio to ``output_path``.

    Args:
        text: Text passed to the pipeline.
        output_path: Destination file handed to ``soundfile.write``.
        preset: Preset name forwarded to ``build_pipeline``.
        use_4bit: Forwarded to ``build_pipeline``.
        use_8bit: Forwarded to ``build_pipeline``.
        profile: When true, generate via ``generate_with_profile`` and print
            time, RTF and peak VRAM; otherwise call ``generate``.

    Returns:
        ``0`` after the file has been written. Failures propagate as
        exceptions.
    """
    from src.text_to_audio import build_pipeline
    import soundfile as sf

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
    if profile:
        out, prof = pipe.generate_with_profile(text)
        print(f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} | VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB")
    else:
        out = pipe.generate(text)

    # The pipeline may return one result dict or a sequence of them; use the
    # first result in the latter case.
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    # torch's Tensor.numpy() raises for tensors that live on an accelerator
    # or are part of the autograd graph, so move to CPU before converting.
    if hasattr(audio, "detach"):
        audio = audio.detach().cpu()
    arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
    # soundfile only accepts 1-D or 2-D data; strip leading singleton axes
    # (presumably a batch axis) so such outputs can still be written.
    while arr.ndim > 2 and arr.shape[0] == 1:
        arr = arr[0]

    # 2-D input is transposed before writing, as in the original code.
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    print(f"Wrote {output_path} ({sr} Hz)")
    return 0
| def main() -> int: | |
| parser = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo") | |
| parser.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio") | |
| parser.add_argument("--model", default="csm-1b", choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"], help="Model preset") | |
| parser.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)") | |
| parser.add_argument("--text", default="", help="Input text (CLI mode)") | |
| parser.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)") | |
| parser.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print") | |
| args = parser.parse_args() | |
| if args.cli: | |
| text = args.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline." | |
| return run_cli( | |
| text=text, | |
| output_path=args.output, | |
| preset=args.model, | |
| use_4bit=args.quantize, | |
| use_8bit=False, | |
| profile=not args.no_profile, | |
| ) | |
| run_gradio(preset=args.model, use_4bit=args.quantize, use_8bit=False) | |
| return 0 | |
# Script entry point: use main()'s return value as the process exit status.
if __name__ == "__main__":
    raise SystemExit(main())