Commit e3f3734 · Keith committed · 0 parent(s)
Initial commit for HF Space
- .cursor/rules/TransformerPrime.mdc +38 -0
- .cursor/skills/transformerprime/SKILL.md +44 -0
- .cursor/skills/transformerprime/reference.md +93 -0
- GEMINI.md +75 -0
- README.md +161 -0
- app.py +107 -0
- demo.py +130 -0
- requirements.txt +23 -0
- src/text_to_audio/__init__.py +3 -0
- src/text_to_audio/__pycache__/__init__.cpython-313.pyc +0 -0
- src/text_to_audio/__pycache__/pipeline.cpython-313.pyc +0 -0
- src/text_to_audio/pipeline.py +234 -0
- tests/test_pipeline.py +39 -0
.cursor/rules/TransformerPrime.mdc
ADDED
@@ -0,0 +1,38 @@
---
alwaysApply: true
---

You are TransformerPrime — an elite, post-2025 reasoning agent whose entire knowledge spine and latent space were sculpted from massive, repeated exposure to:

• The complete Hugging Face Transformers v4.50+ → v5.x source code (modeling_*.py, configuration_*.py, modeling_utils.py, tokenization_*.py, pipelines, trainer.py, generation/, optim/, quantization/, etc.)
• Every official model card and 🤗 model documentation page (/docs/transformers/model_doc/*)
• Hundreds of seminal Transformer papers (Attention Is All You Need → FlashAttention-2/3, RoPE, ALiBi, GQA, MLA, Mamba-2 hybrids, Mixture-of-Experts routing, rotary embeddings, grouped-query, sliding window, YaRN, NTK-aware scaling, QLoRA/DoRA/PEFT, bitsandbytes 8-bit/4-bit/GPTQ/AWQ, torch.compile, SDPA, vLLM/TGI/SGLang inference patterns, etc.)
• Real-world fine-tuning patterns, common failure modes, memory fragmentation behaviors, activation checkpointing trade-offs, gradient accumulation pitfalls, mixed-precision instabilities, and production deployment gotchas

You have internalized:
- Precise parameter naming conventions (hidden_size vs dim vs d_model vs embed_dim)
- Every config field and its semantic meaning (num_attention_heads, num_key_value_heads, intermediate_size, rope_theta, rms_norm_eps, tie_word_embeddings, sliding_window, etc.)
- Head architecture patterns (BERT-style MLM, causal LM, seq2seq, vision, multimodal, audio)
- Attention implementations (eager, sdpa, flash_attention_2, xformers, sageattention)
- KV-cache management, beam search vs sampling vs speculative decoding, logits processors, stopping criteria
- How to read / modify / debug AutoModel.from_pretrained(), PreTrainedTokenizerFast, DataCollatorFor*, Trainer / SFTTrainer / DPOTrainer arguments

Core behavioral rules:

1. Precision first — never hallucinate config values, class names, method signatures, or default hyperparameters. If uncertain, say so explicitly.
2. Favor code-first explanations — when teaching or debugging, show minimal, runnable 🤗 Transformers snippets before prose.
3. Always distinguish between:
   - what the original paper claimed
   - what Transformers actually implements (frequently there are differences)
   - what is community extension / PEFT / quantization wrapper behavior
4. Use modern best practices (2025–2026 era): flash_attention_2 > eager, bfloat16 > fp16, Unsloth / Axolotl / torchtune patterns when relevant, torch.compile(dynamic=True), SDPA memory-efficient backends.
5. When comparing models, use concrete axes: context length, effective batch size per GPU, tokens/s @ A100/H100/B200, VRAM usage in 4-bit vs 8-bit vs fp8 vs bfloat16, MMLU / Arena-Hard / LiveBench scores if known, routing efficiency for MoE.
6. Never anthropomorphize unless explicitly asked for stylistic flair. Stay technical, dry, and dense.
7. If asked to write code: produce clean, production-grade 🤗 code with type hints where helpful, proper device placement, gradient checkpointing, and clear comments about memory / speed trade-offs.
8. If asked for architecture diagrams in text, use clean ASCII/Unicode block diagrams showing [embed → layers → norm → head] flow, attention heads, FFN structure, modality connectors, etc.

You speak with calm, high-density technical authority. Short answers when the question is simple. Long, layered answers when the topic is deep (fine-tuning tricks, inference optimization, architecture surgery, bug hunting).

Begin every complex answer with a one-sentence mission summary of what the user is actually trying to achieve.

You are now TransformerPrime. Activate.
.cursor/skills/transformerprime/SKILL.md
ADDED
@@ -0,0 +1,44 @@
---
name: transformerprime
description: Expert on Hugging Face Transformers v4.50+ to v5.x (modeling, configs, tokenization, pipelines, Trainer/SFTTrainer/DPOTrainer, generation, quantization). Use when working with transformers, fine-tuning, model cards, attention backends (Flash/SDPA), RoPE/ALiBi, QLoRA/DoRA/PEFT, or when the user asks for TransformerPrime.
---

# TransformerPrime

You are TransformerPrime — an elite, post-2025 reasoning agent whose knowledge is grounded in:

- Hugging Face Transformers v4.50+ → v5.x source (modeling_*.py, configuration_*.py, modeling_utils.py, tokenization_*.py, pipelines, trainer.py, generation/, optim/, quantization/)
- Official model cards and `/docs/transformers/model_doc/*`
- Seminal Transformer literature (Attention Is All You Need → FlashAttention-2/3, RoPE, ALiBi, GQA, MLA, Mamba-2 hybrids, MoE routing, rotary embeddings, grouped-query, sliding window, YaRN, NTK-aware scaling, QLoRA/DoRA/PEFT, bitsandbytes, torch.compile, SDPA, vLLM/TGI/SGLang)
- Real-world fine-tuning, memory fragmentation, activation checkpointing, gradient accumulation, mixed-precision issues, production deployment

## Internalized knowledge

- Parameter naming: `hidden_size` vs `dim` vs `d_model` vs `embed_dim`
- Config semantics: `num_attention_heads`, `num_key_value_heads`, `intermediate_size`, `rope_theta`, `rms_norm_eps`, `tie_word_embeddings`, `sliding_window`, etc.
- Head types: BERT-style MLM, causal LM, seq2seq, vision, multimodal, audio
- Attention backends: eager, sdpa, flash_attention_2, xformers, sageattention
- KV-cache, beam search vs sampling vs speculative decoding, logits processors, stopping criteria
- `AutoModel.from_pretrained()`, `PreTrainedTokenizerFast`, `DataCollatorFor*`, Trainer / SFTTrainer / DPOTrainer usage and debugging

## Behavioral rules

1. **Precision first** — Do not invent config values, class names, method signatures, or defaults. If unsure, say so.
2. **Code first** — When teaching or debugging, give minimal runnable Transformers snippets before prose.
3. **Distinguish** — Separate (a) paper claims, (b) Transformers implementation, (c) community/PEFT/quantization wrappers.
4. **Modern stack** — Prefer flash_attention_2 over eager, bfloat16 over fp16; use Unsloth/Axolotl/torchtune where relevant, `torch.compile(dynamic=True)`, SDPA memory-efficient backends.
5. **Concrete comparisons** — Context length, effective batch size per GPU, tokens/s @ A100/H100/B200, VRAM (4-bit/8-bit/fp8/bfloat16), MMLU/Arena-Hard/LiveBench, MoE routing efficiency.
6. **Tone** — Technical, dry, dense. No anthropomorphizing unless asked.
7. **Code output** — Production-style 🤗 code: type hints where useful, correct device placement, gradient checkpointing, comments on memory/speed trade-offs.
8. **Diagrams** — ASCII/Unicode block diagrams: [embed → layers → norm → head], attention heads, FFN, modality connectors.

## Response style

- Short answers for simple questions; long, layered answers for deep topics (fine-tuning, inference optimization, architecture surgery, bug hunting).
- Start every complex answer with one sentence stating the user's goal.

## Additional resources

- For paper vs implementation distinctions (RoPE, GQA, sliding window, attention backends, KV-cache, scaling, PEFT), see [reference.md](reference.md).

You are TransformerPrime. Activate when this skill is applied.
.cursor/skills/transformerprime/reference.md
ADDED
@@ -0,0 +1,93 @@
# Paper vs Transformers implementation reference

Use this when you need to distinguish original papers from HF behavior or community extensions.

---

## RoPE (Rotary Position Embeddings)

| Aspect | Paper (RoFormer et al.) | Transformers |
|--------|-------------------------|--------------|
| Default base | θ = 10000 | Config: `rope_theta` (e.g. 10000, or 1000000 for long-context models) |
| Scaling | — | YaRN/NTK-aware scaling is often applied in modeling code or via the `rope_scaling` config; not in the original paper |
| Application | q, k only | `apply_rotary_pos_emb` in the modeling code; applied to q and k only, config-driven |
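As a concrete illustration, here is a minimal NumPy sketch of rotary embeddings using an interleaved-pair layout. It is illustrative only: HF's `apply_rotary_pos_emb` uses a rotate-half layout on batched q/k tensors, not this helper.

```python
import numpy as np

def rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive dim pairs of a 1-D vector by position-dependent angles.

    Sketch of the math only; the HF implementation differs in tensor layout.
    """
    d = x.shape[-1]
    # One frequency per dim pair: theta_i = base^(-2i/d); arange(0, d, 2) yields 2i.
    freqs = base ** (-np.arange(0, d, 2) / d)
    angles = pos * freqs
    cos, sin = np.cos(angles), np.sin(angles)
    x1, x2 = x[0::2], x[1::2]
    out = np.empty_like(x)
    out[0::2] = x1 * cos - x2 * sin
    out[1::2] = x1 * sin + x2 * cos
    return out

rng = np.random.default_rng(0)
q, k = rng.standard_normal(64), rng.standard_normal(64)
# Pure rotation: norms are preserved ...
assert np.isclose(np.linalg.norm(rope(q, 7)), np.linalg.norm(q))
# ... and q·k after RoPE depends only on the relative offset (10-3 == 17-10).
assert np.isclose(rope(q, 10) @ rope(k, 3), rope(q, 17) @ rope(k, 10))
```

The second assertion is the property that makes RoPE compatible with KV-caching: cached keys never need re-rotation when the query position advances.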
---

## Grouped-query attention (GQA)

| Aspect | Paper | Transformers |
|--------|-------|--------------|
| Heads | num_kv_heads ≤ num_heads | `num_key_value_heads` in config; can be 1 (MQA) or < `num_attention_heads` |
| Repetition | Same KV head reused for multiple Q heads | Implemented via reshaping/broadcasting in attention; no separate "GQA layer" |
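The KV-head reuse can be sketched with a broadcast-and-reshape, mirroring the `repeat_kv`-style trick common in Llama-family modeling code (a sketch of the mechanism, not the HF source):

```python
import numpy as np

def repeat_kv(kv: np.ndarray, n_rep: int) -> np.ndarray:
    """Expand (batch, num_kv_heads, seq, head_dim) to
    (batch, num_kv_heads * n_rep, seq, head_dim) so each KV head
    serves n_rep query heads."""
    b, h_kv, s, d = kv.shape
    kv = kv[:, :, None, :, :]                        # (b, h_kv, 1, s, d)
    kv = np.broadcast_to(kv, (b, h_kv, n_rep, s, d)) # zero-copy view
    return kv.reshape(b, h_kv * n_rep, s, d)         # materialized for matmul

num_attention_heads, num_key_value_heads = 32, 8     # 4 query heads per KV head
k = np.zeros((2, num_key_value_heads, 16, 64))
k_full = repeat_kv(k, num_attention_heads // num_key_value_heads)
assert k_full.shape == (2, 32, 16, 64)
```

With `num_key_value_heads == 1` this degenerates to MQA; with it equal to `num_attention_heads`, to standard multi-head attention.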
---

## Sliding window / local attention

| Aspect | Paper (e.g. Mistral) | Transformers |
|--------|----------------------|--------------|
| Window size | Fixed W | `sliding_window` in config (e.g. 4096); enforced in the attention mask or the attention implementation |
| Causality | Full causal within window | Causal + window; exact behavior is model-specific (check modeling_*.py) |
---

## Attention backends

| Backend | Paper / origin | Transformers |
|---------|----------------|--------------|
| eager | Reference implementation | Fallback; no fused kernels, exact reference gradients |
| sdpa | PyTorch scaled_dot_product_attention | `attn_implementation="sdpa"`; memory-efficient options, backend-dependent |
| flash_attention_2 | FlashAttention-2 paper | `attn_implementation="flash_attention_2"`; requires the flash-attn package |
| xformers | xFormers library | `attn_implementation="xformers"` when available |
| sageattention | SageAttention | Community / optional; not in the core library |
---

## Layer norm placement

| Pattern | Paper | Transformers |
|---------|-------|--------------|
| Pre-LN | LayerNorm before each sublayer | Many modern models: norm before attention and before the FFN |
| Post-LN | LayerNorm after each sublayer | Classic BERT / early GPT style |
| RMSNorm | Root-mean-square norm | `rms_norm_eps` in config; used in LLaMA, Qwen, etc. |

Config and class names are authoritative; papers often describe only one variant.
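The RMSNorm row reduces to a few lines; this is the math only (with `eps` corresponding to the `rms_norm_eps` config field), not HF's module:

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm over the last dim: x / sqrt(mean(x^2) + eps) * weight.
    Unlike LayerNorm there is no mean subtraction and no bias term."""
    variance = np.mean(x ** 2, axis=-1, keepdims=True)
    return x / np.sqrt(variance + eps) * weight

x = np.array([[3.0, 4.0]])
out = rms_norm(x, np.ones(2))
# Output keeps the direction of x but has (approximately) unit RMS.
assert np.allclose(np.sqrt(np.mean(out ** 2)), 1.0, atol=1e-3)
```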
---

## KV-cache

| Aspect | Typical paper | Transformers |
|--------|---------------|--------------|
| Format | Not specified | Past key/values passed as `past_key_value`; tuple/list of (k, v) per layer |
| Static vs dynamic | — | `use_cache=True`; shape and device are implementation-defined (e.g. (batch, heads, seq, head_dim)) |
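A toy version of the tuple-style per-layer cache (recent transformers versions wrap this in `Cache` classes; this shows only the shape mechanics of one layer):

```python
import numpy as np

def append_kv(past, k: np.ndarray, v: np.ndarray):
    """Naive per-layer KV cache: concatenate new keys/values along the
    sequence axis of a (batch, heads, seq, head_dim) tensor."""
    if past is None:
        return k, v
    pk, pv = past
    return np.concatenate([pk, k], axis=2), np.concatenate([pv, v], axis=2)

cache = None
for _ in range(4):                       # one decode step at a time
    k_new = np.zeros((1, 8, 1, 64))
    v_new = np.zeros((1, 8, 1, 64))
    cache = append_kv(cache, k_new, v_new)
assert cache[0].shape == (1, 8, 4, 64)   # seq axis grows each step
```

This is why decode-time VRAM grows linearly in generated length and in `num_key_value_heads` — and why GQA shrinks the cache.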
---

## Long-context scaling (YaRN, NTK-aware)

| Aspect | Papers | Transformers |
|--------|--------|--------------|
| YaRN | Prescribed scaling of RoPE | Often implemented via `rope_scaling` (type + factor); not all models expose the same keys |
| NTK-aware | Alternative RoPE scaling | Same config surface where supported; implementation is model-specific |

Check `configuration_*.py` and the modeling code for the exact `rope_scaling` schema.

---

## PEFT / QLoRA / DoRA

| Aspect | Paper | Transformers |
|--------|-------|--------------|
| LoRA | Hu et al. | `peft` library; not in transformers core |
| QLoRA | Quantized base + LoRA | bitsandbytes (or other quantizers) + PEFT; integration lives in examples/training scripts |
| DoRA | Weight-decomposed LoRA | PEFT adapter type; API in peft, not transformers |

---

## When in doubt

- **Config**: Read `configuration_*.py` and `*Config` for defaults and field semantics.
- **Behavior**: Read `modeling_*.py` (forward, attention, RoPE application).
- **Training**: Trainer/SFTTrainer/DPOTrainer and generation defaults live in their respective modules; do not assume paper values.
GEMINI.md
ADDED
@@ -0,0 +1,75 @@
# TransformerPrime: Text-to-Audio (TTA) Pipeline

## Project Overview

**TransformerPrime** is a high-performance, GPU-accelerated text-to-audio generation suite built on top of the Hugging Face `transformers` ecosystem. It is specifically optimized for NVIDIA RTX 40/50-series and datacenter GPUs (A100/H100/B200), targeting low-latency inference and efficient VRAM management (<10 GB for 1B-parameter models).

### Core Technologies

- **Runtime:** Python 3.10+, PyTorch 2.5.0+
- **Backbone:** HF `transformers` (v4.57+), `accelerate`
- **Optimization:** `bitsandbytes` (4-bit/8-bit quantization), FlashAttention-2
- **Interface:** Gradio (Web UI), CLI (argparse)
- **Audio:** `soundfile`, `numpy`

### Architecture

The project is structured as a modular pipeline wrapper:

- `src/text_to_audio/pipeline.py`: Core logic for model loading, inference, memory profiling, and streaming-style chunking.
- `src/text_to_audio/__init__.py`: Public API surface (`build_pipeline`, `TextToAudioPipeline`).
- `demo.py`: Unified entry point for the Gradio web interface and CLI operations.
- `tests/`: Unit tests for pipeline configuration and logic (mocking model downloads).

---

## Building and Running

### Setup

```bash
# Install dependencies
pip install -r requirements.txt

# Optional: ensure bitsandbytes is installed for quantization support
pip install bitsandbytes
```

### Execution

- **Gradio Web UI (default):**
  ```bash
  python demo.py --model csm-1b --quantize
  ```
- **CLI mode:**
  ```bash
  python demo.py --cli --text "Hello from TransformerPrime." --output output.wav --quantize
  ```

### Testing

```bash
# Run unit tests from the repo root
PYTHONPATH=. pytest tests/
```

---

## Development Conventions

### TransformerPrime Persona

When extending this codebase, adhere to the **TransformerPrime** persona (defined in `.cursor/rules/TransformerPrime.mdc`):

- **Precision:** Never hallucinate config values or method signatures.
- **Modern Standards:** Favor `flash_attention_2` over eager implementations and `bfloat16` over `float16`.
- **Performance First:** Always consider VRAM footprint and Real-Time Factor (RTF). Use `generate_with_profile()` to validate changes.

### Coding Style

- **Type Safety:** Use Python type hints and `from __future__ import annotations`.
- **Configuration:** Use `dataclasses` (e.g., `PipelineConfig`) for structured parameters.
- **Device Management:** Use `accelerate` or `torch.cuda.is_available()` to handle device placement automatically (`device_map="auto"`).
- **Quantization:** Support `bitsandbytes` for 4-bit (`nf4`) and 8-bit loading to ensure compatibility with consumer GPUs.
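A hedged sketch of what the dataclass-based configuration convention could look like — the field names below are hypothetical illustrations, not the project's actual `PipelineConfig`:

```python
from dataclasses import dataclass, field

@dataclass
class PipelineConfig:
    """Illustrative structured-parameters object; field names are hypothetical."""
    model_id: str = "sesame/csm-1b"
    dtype: str = "bfloat16"             # prefer bfloat16 over float16
    device_map: str = "auto"
    use_4bit: bool = False              # bitsandbytes nf4 when True
    generate_kwargs: dict = field(default_factory=dict)  # per-call overrides

cfg = PipelineConfig(use_4bit=True)
assert cfg.model_id == "sesame/csm-1b" and cfg.use_4bit
```

The dataclass keeps defaults in one place and makes configurations easy to construct in tests without loading a model.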
### Key Symbols

- `build_pipeline()`: Primary factory for creating pipeline instances.
- `TextToAudioPipeline.generate_with_profile()`: Returns both audio and performance metrics (VRAM, RTF).
- `TextToAudioPipeline.stream_chunks()`: Generator for processing long audio outputs in fixed-duration slices.

---

## Future Roadmap (TODO)

- [ ] Add support for Kokoro-82M and Qwen3-TTS backends.
- [ ] Implement speculative decoding for faster inference on large TTA models.
- [ ] Add real-time streaming playback in the Gradio UI.
README.md
ADDED
@@ -0,0 +1,161 @@
---
title: MusicSampler
emoji: 🎹
colorFrom: indigo
colorTo: blue
sdk: gradio
sdk_version: 4.0.0
app_file: app.py
pinned: false
license: mit
short_description: Modular AudioGenerator for DAW-INVADER
---

# TransformerPrime Text-to-Audio Pipeline

**One-sentence justification:** We use the Hugging Face `text-to-audio` pipeline with **sesame/csm-1b** (1B-parameter conversational TTS, Llama backbone, voice cloning) as the default; it is natively supported in transformers, fits in **<8–10 GB VRAM** with 4-bit quantization, and supports `device_map="auto"`, bfloat16, and optional Flash Attention 2 for fast inference on RTX 40/50 and A100/H100/B200.

---

## Pipeline flow (ASCII)

```
┌─────────────┐     ┌──────────────────┐     ┌─────────────────────────────────────┐
│ Text        │────▶│ HF Processor /   │────▶│ AutoModelForTextToWaveform          │
│ input       │     │ Tokenizer        │     │ (CSM/Bark/SpeechT5/MusicGen/…)      │
└─────────────┘     └──────────────────┘     │ • device_map="auto"                 │
                                             │ • torch.bfloat16                    │
                                             │ • attn_implementation=flash_attn_2  │
                                             │ • optional: 4-bit (bitsandbytes)    │
                                             └──────────────────┬──────────────────┘
                                                                │
                  ┌─────────────────────────────────────────────┘
                  ▼
       ┌──────────────────────┐     ┌─────────────────────┐
       │ Spectrogram models   │────▶│ Vocoder (HiFi-GAN)  │
       │ (e.g. SpeechT5)      │     │ → waveform          │
       └──────────────────────┘     └──────────┬──────────┘
                  │                            │
                  │ Waveform models (CSM, Bark, MusicGen)
                  └────────────────────────────┼─────────────┐
                                               ▼             │
                                       ┌───────────────┐     │
                                       │ Waveform      │◀────┘
                                       │ (numpy/CPU)   │
                                       └───────┬───────┘
                                               │
                  ┌────────────────────────────┼────────────────────────────┐
                  ▼                            ▼                            ▼
         stream_chunks()             save WAV (soundfile)             Gradio / CLI
      (fixed-duration slices)          + memory profile                  demo
```

---

## Install

```bash
cd c:\Users\Keith\transformerprime
pip install -r requirements.txt
```

Optional: `pip install bitsandbytes` for 4-bit/8-bit quantization (keeps VRAM <10 GB on 1B models).

Run tests (from the repo root): `PYTHONPATH=. pytest tests/`

---

## Usage

### Python API

```python
from src.text_to_audio import build_pipeline, TextToAudioPipeline
import soundfile as sf

# Default: CSM-1B, bfloat16, device_map="auto", Flash Attention 2 when supported
pipe = build_pipeline(preset="csm-1b")

# Low VRAM: 4-bit quantization
pipe_low = build_pipeline(preset="csm-1b", use_4bit=True)

# Generate and profile (time, RTF, VRAM peak)
out, profile = pipe.generate_with_profile("Hello from TransformerPrime. This runs on GPU.")
sf.write("out.wav", out["audio"], out["sampling_rate"])
print(profile)  # {"time_s": ..., "rtf": ..., "vram_peak_mb": ..., "duration_s": ...}

# Streaming-style chunks (post-generation chunking for playback)
for chunk, sr in pipe.stream_chunks("Long text here.", chunk_duration_s=0.5):
    pass  # play or write chunk
```

### CLI

```bash
python demo.py --cli --text "Your sentence here." --output out.wav
python demo.py --cli --model bark-small --quantize --output bark.wav
```

### Gradio

```bash
python demo.py --model csm-1b
# Open http://localhost:7860
```

---

## Expected performance (typical GPU)

| Preset         | GPU (example) | dtype    | VRAM (approx) | RTF (approx) | Note                  |
|----------------|---------------|----------|---------------|--------------|-----------------------|
| csm-1b         | RTX 4090      | bfloat16 | ~8–12 GB      | 0.1–0.3      | 1B, voice cloning     |
| csm-1b         | RTX 4090      | 4-bit    | ~6–8 GB       | 0.15–0.4     | Same, lower VRAM      |
| bark-small     | RTX 4070      | float32  | ~4–6 GB       | 0.2–0.5      | Smaller, multilingual |
| speecht5       | RTX 4060      | float32  | ~2–3 GB       | <0.2         | Spectrogram + vocoder |
| musicgen-small | RTX 4090      | float32  | ~8–10 GB      | 0.3–0.6      | Music/SFX             |

*RTF = real-time factor (wall time / audio duration; <1 means faster than real time).*

---

## Tuning tips

- **Batch inference:** Pass a list of strings: `pipe.generate(["Text one.", "Text two."])`. Watch VRAM; reduce batch size or use 4-bit if you hit OOM.
- **Longer generation:** Increase `max_new_tokens` in `generate_kwargs` (e.g. `pipe.generate(text, generate_kwargs={"max_new_tokens": 512})`). CSM/Bark/MusicGen use this; the exact cap is model-specific.
- **Emotion / style:**
  - **CSM:** Use the chat template with reference audio for voice cloning and tone.
  - **Bark:** Use semantic tokens / history if you customize inputs.
  - **Dia (nari-labs/Dia-1.6B-0626):** Speaker tags `[S1]`/`[S2]` and non-verbal cues like `(laughs)`.
- **Stability:** Prefer `torch.bfloat16` on Ampere and newer; use `use_4bit=True` to stay under 8–10 GB VRAM.
- **CPU fallback:** Omit `device_map="auto"` and set `device=-1` in the pipeline if you have no GPU; generation will be slow.
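The fixed-duration slicing that `stream_chunks()` performs after generation can be sketched as follows (a hypothetical standalone version, not the project's method, which also runs the generation step itself):

```python
import numpy as np

def chunk_waveform(audio: np.ndarray, sr: int, chunk_duration_s: float):
    """Yield fixed-duration slices of an already-generated waveform.
    The last chunk may be shorter than chunk_duration_s."""
    step = int(sr * chunk_duration_s)
    for start in range(0, len(audio), step):
        yield audio[start:start + step], sr

audio = np.zeros(24000 * 2)              # 2 s of audio at 24 kHz
chunks = list(chunk_waveform(audio, 24000, 0.5))
assert len(chunks) == 4 and len(chunks[0][0]) == 12000
```

Because the slicing happens after generation, it smooths playback but does not reduce time-to-first-audio; true streaming would require incremental decoding.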
---

## Error handling and memory

- **OOM:** Enable 4-bit (`use_4bit=True`) or switch to a smaller preset (e.g. `bark-small`, `speecht5`).
- **Flash Attention 2:** If loading fails with an attention-backend error, the code falls back to the default attention implementation for that model.
- **VRAM check:** `generate_with_profile()` returns `vram_peak_mb` when CUDA is available; use it to size batches and the model.

---

## Presets

| Preset         | Model ID                | Description                          |
|----------------|-------------------------|--------------------------------------|
| csm-1b         | sesame/csm-1b           | 1B conversational TTS, voice cloning |
| bark-small     | suno/bark-small         | Multilingual, non-verbal sounds      |
| speecht5       | microsoft/speecht5_tts  | Multi-speaker (x-vector)             |
| musicgen-small | facebook/musicgen-small | Music / SFX                          |

You can pass any `model_id` supported by the `transformers` text-to-audio pipeline via `build_pipeline(model_id="org/model-name")`.

---

## Code layout

- `src/text_to_audio/pipeline.py` — Pipeline wrapper, GPU options, streaming chunks, memory profiling.
- `src/text_to_audio/__init__.py` — Exposes `build_pipeline`, `TextToAudioPipeline`, `list_presets`.
- `demo.py` — Gradio UI and CLI; `python demo.py` (Gradio) or `python demo.py --cli --text "..." --output out.wav`.
- `tests/test_pipeline.py` — Unit tests (no model download); run after `pip install -r requirements.txt`.
app.py
ADDED
@@ -0,0 +1,107 @@
| 1 |
+
"""
|
| 2 |
+
MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER.
|
| 3 |
+
Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
|
| 4 |
+
"""
|
| 5 |
+
|
| 6 |
+
from __future__ import annotations
|
| 7 |
+
|
| 8 |
+
import os
|
| 9 |
+
import torch
|
| 10 |
+
from fastapi import FastAPI, BackgroundTasks
|
| 11 |
+
from fastapi.responses import FileResponse
|
| 12 |
+
from pydantic import BaseModel
|
| 13 |
+
import gradio as gr
|
| 14 |
+
import soundfile as sf
|
| 15 |
+
import numpy as np
|
| 16 |
+
import uuid
|
| 17 |
+
|
| 18 |
+
from src.text_to_audio import build_pipeline
|
| 19 |
+
|
| 20 |
+
# Initialize Pipeline (defaulting to musicgen-small for MusicSampler)
|
| 21 |
+
MODEL_PRESET = os.getenv("MODEL_PRESET", "musicgen-small")
|
| 22 |
+
USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
|
| 23 |
+
|
| 24 |
+
print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
|
| 25 |
+
pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT)
|
| 26 |
+
|
| 27 |
+
# FastAPI Setup
|
| 28 |
+
app = FastAPI(title="MusicSampler API")
|
| 29 |
+
|
| 30 |
+
class GenRequest(BaseModel):
|
| 31 |
+
prompt: str
|
| 32 |
+
duration: float = 5.0
|
| 33 |
+
model: str = MODEL_PRESET
|
| 34 |
+
|
| 35 |
+
@app.post("/generate")
|
| 36 |
+
async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
|
| 37 |
+
"""API Endpoint for DAW-INVADER / Vercel integration."""
|
| 38 |
+
filename = f"gen_{uuid.uuid4()}.wav"
|
| 39 |
+
output_path = os.path.join("/tmp", filename)
|
| 40 |
+
|
| 41 |
+
# Generate audio
|
| 42 |
+
# MusicGen supports 'max_new_tokens' via generate_kwargs
|
| 43 |
+
# 5 seconds ~ 250 tokens for MusicGen small (50 tokens/sec)
|
| 44 |
+
tokens = int(req.duration * 50)
|
    out = pipe.generate(
        req.prompt,
        generate_kwargs={"max_new_tokens": tokens}
    )

    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    if hasattr(audio, "numpy"):
        arr = audio.numpy()
    else:
        arr = np.asarray(audio)

    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)

    # Clean up file after serving
    background_tasks.add_task(os.remove, output_path)

    return FileResponse(output_path, media_type="audio/wav", filename=filename)

# Gradio Interface
def gradio_gen(prompt, duration):
    tokens = int(duration * 50)
    out, profile = pipe.generate_with_profile(
        prompt,
        generate_kwargs={"max_new_tokens": tokens}
    )
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]

    if hasattr(audio, "numpy"):
        arr = audio.numpy()
    else:
        arr = np.asarray(audio)

    path = f"/tmp/gradio_{uuid.uuid4()}.wav"
    sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
    return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"

with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
    gr.Markdown("# 🎹 MusicSampler")
    gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.")

    with gr.Row():
        with gr.Column():
            prompt = gr.Textbox(label="Musical Prompt", placeholder="Lo-fi hip hop beat with smooth rhodes piano...", lines=3)
            duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
            btn = gr.Button("Sample", variant="primary")
        with gr.Column():
            audio_out = gr.Audio(label="Output Sample", type="filepath")
            stats = gr.Label(label="Performance")

    btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])

# Mount Gradio into FastAPI
app = gr.mount_gradio_app(app, ui, path="/")

if __name__ == "__main__":
    import uvicorn
    uvicorn.run(app, host="0.0.0.0", port=7860)
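Both generation paths above derive `max_new_tokens` from the requested clip length via `tokens = int(duration * 50)`. A standalone sketch of that mapping; the 50 tokens-per-second rate matches MusicGen's ~50 Hz codec frame rate and should be treated as a model-specific assumption, not a universal constant:

```python
TOKENS_PER_SECOND = 50  # MusicGen emits roughly 50 codec frames per second of audio


def duration_to_max_new_tokens(duration_s: float, tokens_per_second: int = TOKENS_PER_SECOND) -> int:
    """Convert a requested clip length into a max_new_tokens budget."""
    return int(duration_s * tokens_per_second)


print(duration_to_max_new_tokens(5))   # 250 tokens for a 5 s clip
print(duration_to_max_new_tokens(30))  # 1500 tokens for a 30 s clip
```

Other models use different frame rates, so a per-preset rate table would be needed before reusing this mapping outside MusicGen.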
demo.py
ADDED
@@ -0,0 +1,130 @@
"""
Gradio and CLI entrypoint for the text-to-audio pipeline.
Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"]
"""

from __future__ import annotations

import argparse
import sys

import numpy as np


def run_gradio(
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
) -> None:
    import gradio as gr
    import soundfile as sf
    from src.text_to_audio import build_pipeline, list_presets

    presets = list_presets()
    pipe = build_pipeline(
        preset=preset,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
    )

    def generate_audio(text: str, progress=gr.Progress()) -> str | None:
        if not text or not text.strip():
            return None
        progress(0.2, desc="Generating...")
        try:
            out, profile = pipe.generate_with_profile(text.strip())
            single = out if isinstance(out, dict) else out[0]
            audio = single["audio"]
            sr = single["sampling_rate"]
            if hasattr(audio, "numpy"):
                arr = audio.numpy()
            else:
                arr = np.asarray(audio)
            path = "/tmp/tta_output.wav"
            sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
            progress(1.0, desc=f"Done — {profile.get('time_s', 0):.2f}s, RTF={profile.get('rtf', 0):.2f}")
            return path
        except Exception as e:
            raise gr.Error(str(e)) from e

    with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
        gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
        with gr.Row():
            text_in = gr.Textbox(
                label="Text",
                placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
                lines=3,
            )
        with gr.Row():
            gen_btn = gr.Button("Generate", variant="primary")
        with gr.Row():
            audio_out = gr.Audio(label="Output", type="filepath")
            status = gr.Markdown("")
        gen_btn.click(
            fn=generate_audio,
            inputs=[text_in],
            outputs=[audio_out],
        ).then(
            fn=lambda: "Ready.",
            outputs=[status],
        )
        gr.Markdown("### Prompt ideas\n- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n- **Expressive:** Use punctuation and short sentences for best quality.\n- **Music (MusicGen):** Switch preset to musicgen-small and try: \"Upbeat electronic dance music with a strong bass line.\"")

    app.launch(server_name="0.0.0.0", server_port=7860)


def run_cli(
    text: str,
    output_path: str,
    preset: str = "csm-1b",
    use_4bit: bool = False,
    use_8bit: bool = False,
    profile: bool = True,
) -> int:
    from src.text_to_audio import build_pipeline
    import soundfile as sf

    pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
    if profile:
        out, prof = pipe.generate_with_profile(text)
        print(f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} | VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB")
    else:
        out = pipe.generate(text)
    single = out if isinstance(out, dict) else out[0]
    audio = single["audio"]
    sr = single["sampling_rate"]
    if hasattr(audio, "numpy"):
        arr = audio.numpy()
    else:
        arr = np.asarray(audio)
    sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
    print(f"Wrote {output_path} ({sr} Hz)")
    return 0


def main() -> int:
    parser = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo")
    parser.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio")
    parser.add_argument("--model", default="csm-1b", choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"], help="Model preset")
    parser.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)")
    parser.add_argument("--text", default="", help="Input text (CLI mode)")
    parser.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)")
    parser.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print")
    args = parser.parse_args()

    if args.cli:
        text = args.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline."
        return run_cli(
            text=text,
            output_path=args.output,
            preset=args.model,
            use_4bit=args.quantize,
            use_8bit=False,
            profile=not args.no_profile,
        )
    run_gradio(preset=args.model, use_4bit=args.quantize, use_8bit=False)
    return 0


if __name__ == "__main__":
    sys.exit(main())
requirements.txt
ADDED
@@ -0,0 +1,23 @@
# TransformerPrime TTA pipeline — March 2026
# GPU-accelerated text-to-audio (NVIDIA RTX 40/50, A100/H100/B200)

torch>=2.5.0
transformers>=4.57.0
accelerate>=1.2.0
sentencepiece>=0.2.0
soundfile>=0.12.0
numpy>=1.24.0
gradio>=4.0.0
fastapi>=0.110.0
uvicorn>=0.27.0
pydantic>=2.6.0

# Optional: 4-bit/8-bit quantization (reduces VRAM to <10 GB for 1B models)
bitsandbytes>=0.44.0

# Testing
pytest>=8.0.0

# Optional backends (install if you want Kokoro or Qwen3-TTS)
# kokoro>=0.9.4
# qwen-tts>=0.1.0
src/text_to_audio/__init__.py
ADDED
@@ -0,0 +1,3 @@
from .pipeline import TextToAudioPipeline, build_pipeline, list_presets

__all__ = ["TextToAudioPipeline", "build_pipeline", "list_presets"]
src/text_to_audio/__pycache__/__init__.cpython-313.pyc
ADDED
Binary file (286 Bytes).
src/text_to_audio/__pycache__/pipeline.cpython-313.pyc
ADDED
Binary file (11.3 kB).
src/text_to_audio/pipeline.py
ADDED
@@ -0,0 +1,234 @@
"""
GPU-accelerated text-to-audio pipeline using Hugging Face transformers.
Optimized for NVIDIA RTX 40/50, A100/H100/B200; target <8–10 GB VRAM with quantization.
"""

from __future__ import annotations

import time
from dataclasses import dataclass
from typing import Any, Generator, Iterator

import numpy as np
import torch

try:
    from transformers import pipeline as hf_pipeline
    from transformers import BitsAndBytesConfig
except ImportError as e:
    raise ImportError("Install transformers: pip install transformers") from e

try:
    from accelerate import Accelerator
except ImportError:
    Accelerator = None


PRESETS = {
    "csm-1b": {
        "model_id": "sesame/csm-1b",
        "description": "1B conversational TTS, voice cloning; Llama backbone, ~6–10 GB VRAM (4-bit).",
    },
    "bark-small": {
        "model_id": "suno/bark-small",
        "description": "Multilingual, non-verbal (laugh/cry); smaller footprint.",
    },
    "speecht5": {
        "model_id": "microsoft/speecht5_tts",
        "description": "Spectrogram + HiFi-GAN; multi-speaker with x-vector.",
    },
    "musicgen-small": {
        "model_id": "facebook/musicgen-small",
        "description": "Music/sfx; 32 kHz, generation-style.",
    },
}


@dataclass
class PipelineConfig:
    model_id: str
    use_flash_attention_2: bool = True
    use_4bit: bool = False
    use_8bit: bool = False
    torch_dtype: torch.dtype = torch.bfloat16
    device_map: str = "auto"
    fallback_cpu: bool = True


def _infer_device() -> torch.device | int:
    if Accelerator is not None:
        try:
            acc = Accelerator()
            if str(acc.device).startswith("cuda"):
                return acc.device
        except Exception:
            pass
    if torch.cuda.is_available():
        return 0
    return -1


def _get_model_kwargs(config: PipelineConfig) -> dict[str, Any]:
    kwargs: dict[str, Any] = {
        "device_map": config.device_map,
        "torch_dtype": config.torch_dtype,
    }
    if config.use_4bit or config.use_8bit:
        try:
            kwargs["quantization_config"] = BitsAndBytesConfig(
                load_in_4bit=config.use_4bit,
                load_in_8bit=config.use_8bit,
                bnb_4bit_compute_dtype=config.torch_dtype,
                bnb_4bit_quant_type="nf4",
            )
        except Exception:
            kwargs.pop("quantization_config", None)
    # FlashAttention-2 is only requested for full-precision loads
    if config.use_flash_attention_2 and not (config.use_4bit or config.use_8bit):
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


def build_pipeline(
    model_id: str | None = None,
    preset: str | None = "csm-1b",
    *,
    use_flash_attention_2: bool = True,
    use_4bit: bool = False,
    use_8bit: bool = False,
    torch_dtype: torch.dtype = torch.bfloat16,
    device_map: str = "auto",
) -> TextToAudioPipeline:
    model_id = model_id or (PRESETS.get(preset or "", {}).get("model_id") or "sesame/csm-1b")
    config = PipelineConfig(
        model_id=model_id,
        use_flash_attention_2=use_flash_attention_2,
        use_4bit=use_4bit,
        use_8bit=use_8bit,
        torch_dtype=torch_dtype,
        device_map=device_map,
    )
    return TextToAudioPipeline(config)


def list_presets() -> dict[str, dict[str, str]]:
    return dict(PRESETS)


class TextToAudioPipeline:
    """
    Wrapper around the transformers text-to-audio pipeline with GPU opts,
    streaming-style chunked output, and memory profiling.
    """

    def __init__(self, config: PipelineConfig) -> None:
        self.config = config
        self._pipe: Any = None
        self._device = _infer_device()

    def _ensure_loaded(self) -> None:
        if self._pipe is not None:
            return
        model_kwargs = _get_model_kwargs(self.config)
        try:
            self._pipe = hf_pipeline(
                task="text-to-audio",
                model=self.config.model_id,
                model_kwargs=model_kwargs,
                device=self._device if self.config.device_map != "auto" else None,
            )
        except (TypeError, ValueError) as e:
            if "attn_implementation" in str(e) or "flash_attention_2" in str(e).lower():
                # Retry without FlashAttention-2 for models that do not support it
                model_kwargs.pop("attn_implementation", None)
                self._pipe = hf_pipeline(
                    task="text-to-audio",
                    model=self.config.model_id,
                    model_kwargs=model_kwargs,
                    device=self._device if self.config.device_map != "auto" else None,
                )
            else:
                raise

    def generate(
        self,
        text: str | list[str],
        forward_params: dict[str, Any] | None = None,
        generate_kwargs: dict[str, Any] | None = None,
    ) -> dict[str, Any] | list[dict[str, Any]]:
        self._ensure_loaded()
        forward_params = dict(forward_params or {})  # copy so the caller's dict is not mutated
        generate_kwargs = generate_kwargs or {}
        if generate_kwargs:
            forward_params["generate_kwargs"] = generate_kwargs
        return self._pipe(text, **forward_params)

    def generate_with_profile(
        self,
        text: str,
        forward_params: dict[str, Any] | None = None,
        generate_kwargs: dict[str, Any] | None = None,
    ) -> tuple[dict[str, Any], dict[str, float]]:
        self._ensure_loaded()
        if torch.cuda.is_available():
            torch.cuda.reset_peak_memory_stats()
            torch.cuda.synchronize()
        t0 = time.perf_counter()
        result = self.generate(text, forward_params=forward_params, generate_kwargs=generate_kwargs)
        if torch.cuda.is_available():
            torch.cuda.synchronize()
        elapsed = time.perf_counter() - t0
        profile: dict[str, float] = {"time_s": elapsed}
        if torch.cuda.is_available():
            profile["vram_peak_mb"] = torch.cuda.max_memory_allocated() / (1024 * 1024)
        single = result if isinstance(result, dict) else result[0]
        if isinstance(single, dict) and "audio" in single:
            sr = single.get("sampling_rate", 24000)
            duration_s = single["audio"].size / max(sr, 1) if hasattr(single["audio"], "size") else 0.0
            if duration_s > 0:
                profile["rtf"] = elapsed / duration_s
                profile["duration_s"] = duration_s
        return result, profile

    def stream_chunks(
        self,
        text: str,
        chunk_duration_s: float = 0.5,
        forward_params: dict[str, Any] | None = None,
        generate_kwargs: dict[str, Any] | None = None,
    ) -> Generator[tuple[np.ndarray, int], None, None]:
        self._ensure_loaded()
        out, _ = self.generate_with_profile(text, forward_params=forward_params, generate_kwargs=generate_kwargs)
        single = out if isinstance(out, dict) else out[0]
        audio = single["audio"]
        sr = single["sampling_rate"]
        if hasattr(audio, "numpy"):
            arr = audio.numpy()
        else:
            arr = np.asarray(audio)
        if arr.ndim == 1:
            arr = arr[np.newaxis, :]
        samples_per_chunk = int(chunk_duration_s * sr)
        for start in range(0, arr.shape[-1], samples_per_chunk):
            chunk = arr[..., start : start + samples_per_chunk]
            if chunk.size == 0:
                break
            yield np.squeeze(chunk), sr

    @property
    def sampling_rate(self) -> int:
        self._ensure_loaded()
        return getattr(self._pipe, "sampling_rate", 24000)


def stream_audio_to_file(
    chunk_iter: Iterator[tuple[np.ndarray, int]],
    path: str,
    subtype: str = "PCM_16",
) -> None:
    import soundfile as sf

    first_chunk, sr = next(chunk_iter)
    with sf.SoundFile(path, "w", samplerate=sr, channels=1 if first_chunk.ndim == 1 else first_chunk.shape[0], subtype=subtype) as f:
        f.write(first_chunk)
        for chunk, _ in chunk_iter:
            f.write(chunk)
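The stream_chunks method above slices a finished waveform into fixed-duration windows rather than streaming tokens from the decoder. A minimal standalone sketch of that slicing arithmetic, using plain lists instead of numpy arrays (values are illustrative):

```python
def split_into_chunks(samples, sample_rate, chunk_duration_s=0.5):
    """Slice a 1-D sequence of audio samples into fixed-duration chunks,
    mirroring the loop in TextToAudioPipeline.stream_chunks."""
    samples_per_chunk = int(chunk_duration_s * sample_rate)
    chunks = []
    for start in range(0, len(samples), samples_per_chunk):
        chunk = samples[start:start + samples_per_chunk]
        if chunk:  # skip an empty tail
            chunks.append(chunk)
    return chunks


# 1.2 s of silence at 24 kHz -> three 0.5 s windows, the last one partial
audio = [0.0] * int(1.2 * 24000)
chunks = split_into_chunks(audio, 24000, 0.5)
print(len(chunks), len(chunks[0]), len(chunks[-1]))  # 3 12000 4800
```

Because generation completes before the first chunk is yielded, this gives chunked *output*, not lower time-to-first-audio; true streaming would require hooking the model's decode loop.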
tests/test_pipeline.py
ADDED
@@ -0,0 +1,39 @@
"""Unit tests for text_to_audio pipeline (no GPU or model download required)."""

from __future__ import annotations

import pytest

from src.text_to_audio import list_presets, build_pipeline, TextToAudioPipeline
from src.text_to_audio.pipeline import PipelineConfig, _get_model_kwargs, PRESETS


def test_list_presets() -> None:
    presets = list_presets()
    assert "csm-1b" in presets
    assert "bark-small" in presets
    assert presets["csm-1b"]["model_id"] == "sesame/csm-1b"


def test_build_pipeline_returns_wrapper() -> None:
    pipe = build_pipeline(preset="csm-1b")
    assert isinstance(pipe, TextToAudioPipeline)
    assert pipe.config.model_id == "sesame/csm-1b"


def test_build_pipeline_custom_model_id() -> None:
    pipe = build_pipeline(model_id="suno/bark-small")
    assert pipe.config.model_id == "suno/bark-small"


def test_config_model_kwargs_4bit() -> None:
    config = PipelineConfig(model_id="test", use_4bit=True, use_flash_attention_2=False)
    kwargs = _get_model_kwargs(config)
    assert "quantization_config" in kwargs
    assert kwargs["torch_dtype"] is not None


def test_config_model_kwargs_flash_attn() -> None:
    config = PipelineConfig(model_id="test", use_4bit=False, use_flash_attention_2=True)
    kwargs = _get_model_kwargs(config)
    assert kwargs.get("attn_implementation") == "flash_attention_2"
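The last two tests exercise the loader's key rule: FlashAttention-2 is only requested when no bitsandbytes quantization is active. A dependency-free sketch of that decision logic, with plain dicts standing in for BitsAndBytesConfig (this mirrors _get_model_kwargs rather than reproducing it):

```python
def select_model_kwargs(use_4bit: bool, use_8bit: bool, use_flash_attention_2: bool) -> dict:
    """Mirror the attn/quantization selection in _get_model_kwargs, no transformers needed."""
    kwargs = {"device_map": "auto", "torch_dtype": "bfloat16"}
    if use_4bit or use_8bit:
        # Quantized loads carry a quantization config...
        kwargs["quantization_config"] = {"load_in_4bit": use_4bit, "load_in_8bit": use_8bit}
    if use_flash_attention_2 and not (use_4bit or use_8bit):
        # ...and only full-precision loads request FlashAttention-2
        kwargs["attn_implementation"] = "flash_attention_2"
    return kwargs


print("attn_implementation" in select_model_kwargs(True, False, True))  # False
print(select_model_kwargs(False, False, True)["attn_implementation"])   # flash_attention_2
```

Keeping the rule in one small pure function like this makes it trivially testable without loading a model, which is exactly what the tests above rely on.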
|