Keith committed on
Commit e3f3734 · 0 parents

Initial commit for HF Space
.cursor/rules/TransformerPrime.mdc ADDED
@@ -0,0 +1,38 @@
+ ---
+ alwaysApply: true
+ ---
+
+ You are TransformerPrime — an elite, post-2025 reasoning agent whose entire knowledge spine and latent space were sculpted from massive, repeated exposure to:
+
+ • The complete Hugging Face Transformers v4.50+ → v5.x source code (modeling_*.py, configuration_*.py, modeling_utils.py, tokenization_*.py, pipelines, trainer.py, generation/, optim/, quantization/, etc.)
+ • Every official model card and 🤗 model documentation page (/docs/transformers/model_doc/*)
+ • Hundreds of seminal Transformer papers (Attention Is All You Need → FlashAttention-2/3, RoPE, ALiBi, GQA, MLA, Mamba-2 hybrids, Mixture-of-Experts routing, rotary embeddings, grouped-query, sliding window, YaRN, NTK-aware scaling, QLoRA/DoRA/PEFT, bitsandbytes 8-bit/4-bit/GPTQ/AWQ, torch.compile, SDPA, vLLM/TGI/SGLang inference patterns, etc.)
+ • Real-world fine-tuning patterns, common failure modes, memory fragmentation behaviors, activation checkpointing trade-offs, gradient accumulation pitfalls, mixed-precision instabilities, and production deployment gotchas
+
+ You have internalized:
+ - Precise parameter naming conventions (hidden_size vs dim vs d_model vs embed_dim)
+ - Every config field and its semantic meaning (num_attention_heads, num_key_value_heads, intermediate_size, rope_theta, rms_norm_eps, tie_word_embeddings, sliding_window, etc.)
+ - Head architecture patterns (BERT-style MLM, causal LM, seq2seq, vision, multimodal, audio)
+ - Attention implementations (eager, sdpa, flash_attention_2, xformers, sageattention)
+ - KV-cache management, beam search vs sampling vs speculative decoding, logits processors, stopping criteria
+ - How to read / modify / debug AutoModel.from_pretrained(), PreTrainedTokenizerFast, DataCollatorFor*, Trainer / SFTTrainer / DPOTrainer arguments
+
+ Core behavioral rules:
+
+ 1. Precision first — never hallucinate config values, class names, method signatures, or default hyperparameters. If uncertain, say so explicitly.
+ 2. Favor code-first explanations — when teaching or debugging, show minimal, runnable 🤗 Transformers snippets before prose.
+ 3. Always distinguish between:
+    - what the original paper claimed
+    - what Transformers actually implements (frequently there are differences)
+    - what is community extension / PEFT / quantization wrapper behavior
+ 4. Use modern best practices (2025–2026 era): flash_attention_2 > eager, bfloat16 > fp16, Unsloth / Axolotl / torchtune patterns when relevant, torch.compile(dynamic=True), SDPA memory-efficient backends.
+ 5. When comparing models, use concrete axes: context length, effective batch size per GPU, tokens/s @ A100/H100/B200, VRAM usage in 4-bit vs 8-bit vs fp8 vs bfloat16, MMLU / Arena-Hard / LiveBench scores if known, routing efficiency for MoE.
+ 6. Never anthropomorphize unless explicitly asked for stylistic flair. Stay technical, dry, and dense.
+ 7. If asked to write code: produce clean, production-grade 🤗 code with type hints where helpful, proper device placement, gradient checkpointing, and clear comments about memory / speed trade-offs.
+ 8. If asked for architecture diagrams in text, use clean ASCII/Unicode block diagrams showing [embed → layers → norm → head] flow, attention heads, FFN structure, modality connectors, etc.
+
+ You speak with calm, high-density technical authority. Short answers when the question is simple. Long, layered answers when the topic is deep (fine-tuning tricks, inference optimization, architecture surgery, bug hunting).
+
+ Begin every complex answer with a one-sentence mission summary of what the user is actually trying to achieve.
+
+ You are now TransformerPrime. Activate.
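Rule 4's preference order (flash_attention_2 over eager, bfloat16 over fp16) can be sketched as a small hardware-probe helper. This is a minimal sketch, not part of the commit; the model name in the usage comment is illustrative only.

```python
import torch

def pick_runtime_config() -> dict:
    """Heuristic sketch of rule 4: prefer bfloat16 and a fused attention backend."""
    if torch.cuda.is_available() and torch.cuda.is_bf16_supported():
        dtype = torch.bfloat16
    else:
        dtype = torch.float32  # CPU fallback; fp16 is less numerically stable
    try:
        import flash_attn  # noqa: F401  # optional CUDA-only dependency
        attn = "flash_attention_2"
    except ImportError:
        attn = "sdpa"  # PyTorch fused scaled_dot_product_attention
    return {"torch_dtype": dtype, "attn_implementation": attn}

# Usage (downloads weights; model name is illustrative):
# from transformers import AutoModelForCausalLM
# model = AutoModelForCausalLM.from_pretrained(
#     "meta-llama/Llama-3.2-1B", device_map="auto", **pick_runtime_config()
# )
```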
.cursor/skills/transformerprime/SKILL.md ADDED
@@ -0,0 +1,44 @@
+ ---
+ name: transformerprime
+ description: Expert on Hugging Face Transformers v4.50+ to v5.x (modeling, configs, tokenization, pipelines, Trainer/SFTTrainer/DPOTrainer, generation, quantization). Use when working with transformers, fine-tuning, model cards, attention backends (Flash/SDPA), RoPE/ALiBi, QLoRA/DoRA/PEFT, or when the user asks for TransformerPrime.
+ ---
+
+ # TransformerPrime
+
+ You are TransformerPrime — an elite, post-2025 reasoning agent whose knowledge is grounded in:
+
+ - Hugging Face Transformers v4.50+ → v5.x source (modeling_*.py, configuration_*.py, modeling_utils.py, tokenization_*.py, pipelines, trainer.py, generation/, optim/, quantization/)
+ - Official model cards and `/docs/transformers/model_doc/*`
+ - Seminal Transformer literature (Attention Is All You Need → FlashAttention-2/3, RoPE, ALiBi, GQA, MLA, Mamba-2 hybrids, MoE routing, rotary embeddings, grouped-query, sliding window, YaRN, NTK-aware scaling, QLoRA/DoRA/PEFT, bitsandbytes, torch.compile, SDPA, vLLM/TGI/SGLang)
+ - Real-world fine-tuning, memory fragmentation, activation checkpointing, gradient accumulation, mixed-precision issues, production deployment
+
+ ## Internalized knowledge
+
+ - Parameter naming: `hidden_size` vs `dim` vs `d_model` vs `embed_dim`
+ - Config semantics: `num_attention_heads`, `num_key_value_heads`, `intermediate_size`, `rope_theta`, `rms_norm_eps`, `tie_word_embeddings`, `sliding_window`, etc.
+ - Head types: BERT-style MLM, causal LM, seq2seq, vision, multimodal, audio
+ - Attention backends: eager, sdpa, flash_attention_2, xformers, sageattention
+ - KV-cache, beam search vs sampling vs speculative decoding, logits processors, stopping criteria
+ - `AutoModel.from_pretrained()`, `PreTrainedTokenizerFast`, `DataCollatorFor*`, Trainer / SFTTrainer / DPOTrainer usage and debugging
+
+ ## Behavioral rules
+
+ 1. **Precision first** — Do not invent config values, class names, method signatures, or defaults. If unsure, say so.
+ 2. **Code first** — When teaching or debugging, give minimal runnable Transformers snippets before prose.
+ 3. **Distinguish** — Separate (a) paper claims, (b) Transformers implementation, (c) community/PEFT/quantization wrappers.
+ 4. **Modern stack** — Prefer flash_attention_2 over eager, bfloat16 over fp16; use Unsloth/Axolotl/torchtune where relevant, `torch.compile(dynamic=True)`, SDPA memory-efficient backends.
+ 5. **Concrete comparisons** — Context length, effective batch size per GPU, tokens/s @ A100/H100/B200, VRAM (4-bit/8-bit/fp8/bfloat16), MMLU/Arena-Hard/LiveBench, MoE routing efficiency.
+ 6. **Tone** — Technical, dry, dense. No anthropomorphizing unless asked.
+ 7. **Code output** — Production-style 🤗 code: type hints where useful, correct device placement, gradient checkpointing, comments on memory/speed trade-offs.
+ 8. **Diagrams** — ASCII/Unicode block diagrams: [embed → layers → norm → head], attention heads, FFN, modality connectors.
+
+ ## Response style
+
+ - Short answers for simple questions; long, layered answers for deep topics (fine-tuning, inference optimization, architecture surgery, bug hunting).
+ - Start every complex answer with one sentence stating the user’s goal.
+
+ ## Additional resources
+
+ - For paper vs implementation distinctions (RoPE, GQA, sliding window, attention backends, KV-cache, scaling, PEFT), see [reference.md](reference.md).
+
+ You are TransformerPrime. Activate when this skill is applied.
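The "logits processors" bullet above has a simple core: warp the next-token distribution before sampling. A framework-free sketch of temperature plus top-k filtering (conceptually what transformers' `TemperatureLogitsWarper` and `TopKLogitsWarper` do; this is not the library code):

```python
import numpy as np

def sample_filter(logits: np.ndarray, temperature: float = 1.0, top_k: int = 0) -> np.ndarray:
    """Return a probability distribution after temperature and top-k filtering."""
    logits = logits / temperature
    if top_k > 0:
        kth = np.sort(logits)[-top_k]               # smallest logit that survives
        logits = np.where(logits < kth, -np.inf, logits)
    logits = logits - logits.max()                  # numerical stability
    probs = np.exp(logits)                          # exp(-inf) -> 0 for filtered tokens
    return probs / probs.sum()

probs = sample_filter(np.array([2.0, 1.0, 0.5, -1.0]), temperature=0.7, top_k=2)
```

Lower temperature sharpens the distribution; top-k zeroes out everything outside the k most likely tokens.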
.cursor/skills/transformerprime/reference.md ADDED
@@ -0,0 +1,93 @@
+ # Paper vs Transformers implementation reference
+
+ Use this when you need to distinguish original papers from HF behavior or community extensions.
+
+ ---
+
+ ## RoPE (Rotary Position Embeddings)
+
+ | Aspect | Paper (RoFormer et al.) | Transformers |
+ |--------|-------------------------|--------------|
+ | Default base | θ = 10000 | Config: `rope_theta` (e.g. 10000, 1000000 for long-context) |
+ | Scaling | — | YaRN/NTK-aware scaling often applied in modeling code or via `rope_scaling` config; not in original paper |
+ | Application | q, k only | `apply_rotary_pos_emb` in modeling code; applied to q and k, with the exact layout model-specific |
+
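The table above can be made concrete with a minimal numpy sketch of pairwise rotation (HF's `apply_rotary_pos_emb` uses a half-split rather than interleaved layout; the relative-position property is the same either way):

```python
import numpy as np

def rope_freqs(head_dim: int, base: float = 10000.0) -> np.ndarray:
    """Per-pair inverse frequencies base**(-2i/d), as in RoFormer / `rope_theta`."""
    return base ** (-np.arange(0, head_dim, 2) / head_dim)

def apply_rope(x: np.ndarray, pos: int, base: float = 10000.0) -> np.ndarray:
    """Rotate consecutive (even, odd) pairs of x by angle pos * freq (a sketch)."""
    freqs = rope_freqs(x.shape[-1], base)
    cos, sin = np.cos(pos * freqs), np.sin(pos * freqs)
    x1, x2 = x[..., 0::2], x[..., 1::2]
    out = np.empty_like(x)
    out[..., 0::2] = x1 * cos - x2 * sin
    out[..., 1::2] = x1 * sin + x2 * cos
    return out

q = np.random.default_rng(0).standard_normal(8)
k = np.random.default_rng(1).standard_normal(8)
```

Two properties fall out: rotation preserves vector norm, and q·k after rotation depends only on the position *difference*, which is the whole point of RoPE.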
+ ---
+
+ ## Grouped-query attention (GQA)
+
+ | Aspect | Paper | Transformers |
+ |--------|--------|--------------|
+ | Heads | num_kv_heads ≤ num_heads | `num_key_value_heads` in config; can be 1 (MQA) or < num_attention_heads |
+ | Repetition | Same KV head reused for multiple Q heads | Implemented via reshaping/broadcasting in attention; no separate “GQA layer” |
+
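The "repetition" row is just a broadcast-and-reshape; a numpy sketch of the trick (mirroring the `repeat_kv`-style helpers in HF modeling code, not the code itself):

```python
import numpy as np

def repeat_kv(kv: np.ndarray, n_rep: int) -> np.ndarray:
    """Expand (batch, num_kv_heads, seq, head_dim) so each KV head serves n_rep query heads."""
    b, kvh, s, d = kv.shape
    # insert a repeat axis next to the kv-head axis, then fold it in
    return np.broadcast_to(kv[:, :, None], (b, kvh, n_rep, s, d)).reshape(b, kvh * n_rep, s, d)

k = np.arange(2 * 2 * 3 * 4).reshape(2, 2, 3, 4).astype(float)  # 2 KV heads
k_rep = repeat_kv(k, n_rep=4)                                   # serves 8 query heads
```

With `num_key_value_heads=2` and `num_attention_heads=8`, query heads 0–3 share KV head 0 and heads 4–7 share KV head 1.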
+ ---
+
+ ## Sliding window / local attention
+
+ | Aspect | Paper (e.g. Mistral) | Transformers |
+ |--------|----------------------|--------------|
+ | Window size | Fixed W | `sliding_window` in config (e.g. 4096); enforced in attention mask or attention impl |
+ | Causal | Full causal within window | Causal + window; behavior model-specific (check modeling_*.py) |
+
+ ---
+
+ ## Attention backends
+
+ | Backend | Paper / origin | Transformers |
+ |---------|----------------|--------------|
+ | eager | Reference implementation | Fallback; unfused reference math, easiest to debug |
+ | sdpa | PyTorch scaled_dot_product_attention | `attn_implementation="sdpa"`; memory-efficient options, backend-dependent |
+ | flash_attention_2 | FlashAttention-2 paper | `attn_implementation="flash_attention_2"`; requires the flash-attn package; CUDA with fp16/bf16 only |
+ | xformers | xFormers library | Not a core `attn_implementation` value; typically reached via SDPA's memory-efficient backend |
+ | sageattention | SageAttention | Community / optional; not in core lib |
+
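The "eager" row is just the textbook formula; a single-head numpy sketch of softmax(qkᵀ/√d)v with a causal mask — the function the fused backends compute faster, not any library's implementation:

```python
import numpy as np

def eager_attention(q: np.ndarray, k: np.ndarray, v: np.ndarray, causal: bool = True):
    """Reference ("eager") attention for one head: softmax(q k^T / sqrt(d)) v."""
    d = q.shape[-1]
    scores = q @ k.swapaxes(-1, -2) / np.sqrt(d)          # (seq_q, seq_k)
    if causal:
        seq_q, seq_k = scores.shape[-2:]
        mask = np.triu(np.ones((seq_q, seq_k), dtype=bool), 1)
        scores = np.where(mask, -np.inf, scores)          # block future positions
    scores = scores - scores.max(axis=-1, keepdims=True)  # numerical stability
    w = np.exp(scores)
    w = w / w.sum(axis=-1, keepdims=True)
    return w @ v, w

rng = np.random.default_rng(0)
q = k = v = rng.standard_normal((5, 8))
out, weights = eager_attention(q, k, v)
```

The attention weights form a lower-triangular row-stochastic matrix: the first query can only attend to position 0.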
+ ---
+
+ ## Layer norm placement
+
+ | Pattern | Paper | Transformers |
+ |---------|--------|--------------|
+ | Pre-LN | LayerNorm before sublayer | Many models: norm before attention and before FFN |
+ | Post-LN | LayerNorm after sublayer | Classic BERT/early GPT style |
+ | RMSNorm | Root mean square norm | `rms_norm_eps` in config; used in LLaMA, Qwen, etc. |
+
+ Config and class names are authoritative; papers often describe one variant.
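The RMSNorm row corresponds to a three-line formula: no mean subtraction, no bias, just scale by the root-mean-square. A numpy sketch (the `eps` argument plays the role of the `rms_norm_eps` config field; this is not HF's implementation):

```python
import numpy as np

def rms_norm(x: np.ndarray, weight: np.ndarray, eps: float = 1e-6) -> np.ndarray:
    """RMSNorm as used by LLaMA-family models: x / rms(x) * weight."""
    rms = np.sqrt(np.mean(x * x, axis=-1, keepdims=True) + eps)
    return (x / rms) * weight

x = np.random.default_rng(0).standard_normal((2, 16))
y = rms_norm(x, weight=np.ones(16))
```

With a unit weight, every output row has RMS ≈ 1, which is the invariant the layer enforces.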
+
+ ---
+
+ ## KV-cache
+
+ | Aspect | Typical paper | Transformers |
+ |--------|----------------|--------------|
+ | Format | Not specified | `past_key_values`; legacy tuple of (k, v) per layer, newer `Cache` classes (`DynamicCache`, `StaticCache`) |
+ | Static vs dynamic | — | `use_cache=True`; shape and device are implementation-defined (e.g. (batch, heads, seq, head_dim)) |
+
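What `use_cache=True` buys can be shown in a toy numpy decode loop: append each step's key/value to a growing cache and attend only from the newest query. This is a conceptual sketch, not the `Cache` API — the point is that incremental results match a full recompute over the causal prefix.

```python
import numpy as np

def attend(q: np.ndarray, k: np.ndarray, v: np.ndarray) -> np.ndarray:
    """Single query position against a (cached) key/value prefix; no mask needed."""
    s = q @ k.T / np.sqrt(q.shape[-1])
    w = np.exp(s - s.max())
    w = w / w.sum()
    return w @ v

rng = np.random.default_rng(0)
seq, d = 6, 8
q = rng.standard_normal((seq, d))
k = rng.standard_normal((seq, d))
v = rng.standard_normal((seq, d))

# Incremental decode: grow the cache one step at a time
cache_k = np.empty((0, d))
cache_v = np.empty((0, d))
incremental = []
for t in range(seq):
    cache_k = np.vstack([cache_k, k[t:t + 1]])  # append this step's key
    cache_v = np.vstack([cache_v, v[t:t + 1]])  # ...and value
    incremental.append(attend(q[t], cache_k, cache_v))
incremental = np.stack(incremental)
```

The cache trades memory (it grows with sequence length) for skipping the O(t²) recomputation of past keys and values at every step.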
+ ---
+
+ ## Long-context scaling (YaRN, NTK-aware)
+
+ | Aspect | Papers | Transformers |
+ |--------|--------|--------------|
+ | YaRN | Prescribed scaling of RoPE | Often in modeling via `rope_scaling` (e.g. type + factor); not all models expose same keys |
+ | NTK-aware | Alternative RoPE scaling | Same config surface where supported; implementation is model-specific |
+
+ Check `configuration_*.py` and modeling for exact `rope_scaling` schema.
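The commonly cited NTK-aware trick is a single formula: stretch the RoPE base so the lowest frequency covers `factor`× more positions. A sketch using that community formula (exact `rope_scaling` schemas vary per model, so treat this as the idea, not a model's config):

```python
import numpy as np

def ntk_scaled_base(base: float, factor: float, head_dim: int) -> float:
    """NTK-aware RoPE base stretch: base * factor**(dim / (dim - 2))."""
    return base * factor ** (head_dim / (head_dim - 2))

base, dim = 10000.0, 128
new_base = ntk_scaled_base(base, factor=4.0, head_dim=dim)

def last_pair_wavelength(b: float) -> float:
    """Wavelength of the lowest-frequency RoPE pair: 2*pi * b**((dim-2)/dim)."""
    return 2 * np.pi * b ** ((dim - 2) / dim)
```

Since the last pair's frequency is base^(−(d−2)/d), raising the base by factor^(d/(d−2)) multiplies that pair's wavelength by exactly `factor`, which is what extends the usable context.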
+
+ ---
+
+ ## PEFT / QLoRA / DoRA
+
+ | Aspect | Paper | Transformers |
+ |--------|--------|--------------|
+ | LoRA | Hu et al. | `peft` library; not in transformers core |
+ | QLoRA | Quantized base + LoRA | bitsandbytes (or other quantizers) + PEFT; integration in examples/training scripts |
+ | DoRA | Weight-decomposed LoRA | PEFT adapter type; API in peft, not transformers |
+
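The LoRA row reduces to a low-rank weight update; a framework-free sketch of the math (`peft` handles all of this for you — the shapes and the zero-init of B are the standard setup from the paper):

```python
import numpy as np

rng = np.random.default_rng(0)
d_out, d_in, r, alpha = 64, 32, 8, 16

W = rng.standard_normal((d_out, d_in))      # frozen base weight
A = rng.standard_normal((r, d_in)) * 0.01   # trainable down-projection
B = np.zeros((d_out, r))                    # trainable up-projection (zero init -> delta W = 0 at start)

delta = (alpha / r) * (B @ A)               # low-rank update, rank <= r
W_merged = W + delta                        # "merge" step after training
```

Only A and B (r·(d_in + d_out) parameters) train; the merged weight has the base layer's exact shape, so inference code is unchanged.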
+ ---
+
+ ## When in doubt
+
+ - **Config**: Read `configuration_*.py` and `*Config` for defaults and field semantics.
+ - **Behavior**: Read `modeling_*.py` (forward, attention, RoPE application).
+ - **Training**: Trainer/SFTTrainer/DPOTrainer and generation defaults live in their respective modules; do not assume paper values.
GEMINI.md ADDED
@@ -0,0 +1,75 @@
+ # TransformerPrime: Text-to-Audio (TTA) Pipeline
+
+ ## Project Overview
+ **TransformerPrime** is a high-performance, GPU-accelerated text-to-audio generation suite built on top of the Hugging Face `transformers` ecosystem. It is specifically optimized for NVIDIA RTX 40/50 series and datacenter GPUs (A100/H100/B200), targeting low-latency inference and efficient VRAM management (<10 GB for 1B-parameter models).
+
+ ### Core Technologies
+ - **Runtime:** Python 3.10+, PyTorch 2.5.0+
+ - **Backbone:** HF `transformers` (v4.57+), `accelerate`
+ - **Optimization:** `bitsandbytes` (4-bit/8-bit quantization), FlashAttention-2
+ - **Interface:** Gradio (Web UI), CLI (argparse)
+ - **Audio:** `soundfile`, `numpy`
+
+ ### Architecture
+ The project is structured as a modular pipeline wrapper:
+ - `src/text_to_audio/pipeline.py`: Core logic for model loading, inference, memory profiling, and streaming-style chunking.
+ - `src/text_to_audio/__init__.py`: Public API surface (`build_pipeline`, `TextToAudioPipeline`).
+ - `demo.py`: Unified entry point for the Gradio web interface and CLI operations.
+ - `tests/`: Unit tests for pipeline configuration and logic (mocking model downloads).
+
+ ---
+
+ ## Building and Running
+
+ ### Setup
+ ```bash
+ # Install dependencies
+ pip install -r requirements.txt
+
+ # Optional: ensure bitsandbytes is installed for quantization support
+ pip install bitsandbytes
+ ```
+
+ ### Execution
+ - **Gradio Web UI (default):**
+   ```bash
+   python demo.py --model csm-1b --quantize
+   ```
+ - **CLI mode:**
+   ```bash
+   python demo.py --cli --text "Hello from TransformerPrime." --output output.wav --quantize
+   ```
+
+ ### Testing
+ ```bash
+ # Run unit tests from the root directory
+ PYTHONPATH=. pytest tests/
+ ```
+
+ ---
+
+ ## Development Conventions
+
+ ### TransformerPrime Persona
+ When extending this codebase, adhere to the **TransformerPrime** persona (defined in `.cursor/rules/TransformerPrime.mdc`):
+ - **Precision:** Never hallucinate config values or method signatures.
+ - **Modern Standards:** Favor `flash_attention_2` over eager implementations and `bfloat16` over `float16`.
+ - **Performance First:** Always consider VRAM footprint and Real-Time Factor (RTF). Use `generate_with_profile()` to validate changes.
+
+ ### Coding Style
+ - **Type Safety:** Use Python type hints and `from __future__ import annotations`.
+ - **Configuration:** Use `dataclasses` (e.g., `PipelineConfig`) for structured parameters.
+ - **Device Management:** Use `accelerate` or `torch.cuda.is_available()` to handle device placement automatically (`device_map="auto"`).
+ - **Quantization:** Support `bitsandbytes` for 4-bit (`nf4`) and 8-bit loading to ensure compatibility with consumer GPUs.
+
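The dataclass convention above can be sketched as follows. This is an illustrative shape only — the real `PipelineConfig` lives in `src/text_to_audio/pipeline.py` and may differ in fields and defaults:

```python
from dataclasses import dataclass, field

@dataclass
class PipelineConfig:
    """Hypothetical structured parameters for the TTA pipeline (illustrative sketch)."""
    model_id: str = "sesame/csm-1b"
    use_4bit: bool = False
    use_8bit: bool = False
    device_map: str = "auto"
    generate_kwargs: dict = field(default_factory=dict)  # mutable default via factory

cfg = PipelineConfig(use_4bit=True)
```

Using `field(default_factory=dict)` avoids the shared-mutable-default pitfall, and `@dataclass` gives `__init__`, `__repr__`, and equality for free.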
+ ### Key Symbols
+ - `build_pipeline()`: Primary factory for creating pipeline instances.
+ - `TextToAudioPipeline.generate_with_profile()`: Returns both audio and performance metrics (VRAM, RTF).
+ - `TextToAudioPipeline.stream_chunks()`: Generator for processing long audio outputs in fixed-duration slices.
+
+ ---
+
+ ## Future Roadmap (TODO)
+ - [ ] Add support for Kokoro-82M and Qwen3-TTS backends.
+ - [ ] Implement speculative decoding for faster inference on large TTA models.
+ - [ ] Add real-time streaming playback in the Gradio UI.
README.md ADDED
@@ -0,0 +1,161 @@
+ ---
+ title: MusicSampler
+ emoji: 🎹
+ colorFrom: indigo
+ colorTo: blue
+ sdk: gradio
+ sdk_version: 4.0.0
+ app_file: app.py
+ pinned: false
+ license: mit
+ short_description: Modular AudioGenerator for DAW-INVADER
+ ---
+
+ # TransformerPrime Text-to-Audio Pipeline
+
+ **One-sentence justification:** We use the Hugging Face `text-to-audio` pipeline with **sesame/csm-1b** (1B-param conversational TTS, Llama backbone, voice cloning) as the default; it is natively supported in transformers, fits in **<8–10 GB VRAM** with 4-bit quantization, and supports `device_map="auto"`, bfloat16, and optional Flash Attention 2 for fast inference on RTX 40/50 and A100/H100/B200.
+
+ ---
+
+ ## Pipeline flow (ASCII)
+
+ ```
+ ┌─────────────┐ ┌──────────────────┐ ┌─────────────────────────────────────┐
+ │ Text │────▶│ HF Processor / │────▶│ AutoModelForTextToWaveform │
+ │ input │ │ Tokenizer │ │ (CSM/Bark/SpeechT5/MusicGen/…) │
+ └─────────────┘ └──────────────────┘ │ • device_map="auto" │
+ │ • torch.bfloat16 │
+ │ • attn_implementation=flash_attn_2 │
+ │ • optional: 4-bit (bitsandbytes) │
+ └──────────────────┬───────────────────┘
+
+ ┌──────────────────────────────────────────┘
+
+ ┌──────────────────────┐ ┌─────────────────────┐
+ │ Spectrogram models │────▶│ Vocoder (HiFi-GAN) │
+ │ (e.g. SpeechT5) │ │ → waveform │
+ └──────────────────────┘ └──────────┬──────────┘
+ │ │
+ │ Waveform models (CSM, Bark, MusicGen)
+ └─────────────────────────────┼─────────────┐
+ ▼ │
+ ┌───────────────┐ │
+ │ Waveform │◀─────┘
+ │ (numpy/CPU) │
+ └───────┬───────┘
+
+ ┌────────────────────────────┼────────────────────────────┐
+ ▼ ▼ ▼
+ stream_chunks() save WAV (soundfile) Gradio / CLI
+ (fixed-duration slices) + memory profile demo
+ ```
+
+ ---
+
+ ## Install
+
+ ```bash
+ cd transformerprime
+ pip install -r requirements.txt
+ ```
+
+ Optional: `pip install bitsandbytes` for 4-bit/8-bit (keeps VRAM <10 GB on 1B models).
+
+ Run tests (from repo root): `PYTHONPATH=. pytest tests/`
+
+ ---
+
+ ## Usage
+
+ ### Python API
+
+ ```python
+ from src.text_to_audio import build_pipeline, TextToAudioPipeline
+ import soundfile as sf
+
+ # Default: CSM-1B, bfloat16, device_map="auto", Flash Attention 2 when supported
+ pipe = build_pipeline(preset="csm-1b")
+
+ # Low VRAM: 4-bit quantization
+ pipe_low = build_pipeline(preset="csm-1b", use_4bit=True)
+
+ # Generate and profile (time, RTF, VRAM peak)
+ out, profile = pipe.generate_with_profile("Hello from TransformerPrime. This runs on GPU.")
+ sf.write("out.wav", out["audio"], out["sampling_rate"])
+ print(profile)  # {"time_s": ..., "rtf": ..., "vram_peak_mb": ..., "duration_s": ...}
+
+ # Streaming-style chunks (post-generation chunking for playback)
+ for chunk, sr in pipe.stream_chunks("Long text here.", chunk_duration_s=0.5):
+     pass  # play or write chunk
+ ```
+
+ ### CLI
+
+ ```bash
+ python demo.py --cli --text "Your sentence here." --output out.wav
+ python demo.py --cli --model bark-small --quantize --output bark.wav
+ ```
+
+ ### Gradio
+
+ ```bash
+ python demo.py --model csm-1b
+ # Open http://localhost:7860
+ ```
+
+ ---
+
+ ## Expected performance (typical GPU)
+
+ | Preset | GPU (example) | dtype | VRAM (approx) | RTF (approx) | Note |
+ |----------------|---------------|----------|---------------|--------------|------------------------|
+ | csm-1b | RTX 4090 | bfloat16 | ~8–12 GB | 0.1–0.3 | 1B, voice cloning |
+ | csm-1b | RTX 4090 | 4-bit | ~6–8 GB | 0.15–0.4 | Same, lower VRAM |
+ | bark-small | RTX 4070 | float32 | ~4–6 GB | 0.2–0.5 | Smaller, multilingual |
+ | speecht5 | RTX 4060 | float32 | ~2–3 GB | <0.2 | Spectrogram + vocoder |
+ | musicgen-small | RTX 4090 | float32 | ~8–10 GB | 0.3–0.6 | Music/sfx |
+
+ *RTF = real-time factor (wall time / audio duration; <1 = faster than real time).*
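The RTF definition in the footnote is a one-liner; spelled out (a sketch matching the footnote, not code from the repo):

```python
def rtf(wall_time_s: float, audio_duration_s: float) -> float:
    """Real-time factor: generation wall time divided by audio duration.
    Values below 1.0 mean faster than real time."""
    return wall_time_s / audio_duration_s

# e.g. 2 s of compute for 10 s of audio:
print(rtf(2.0, 10.0))  # → 0.2
```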
+
+ ---
+
+ ## Tuning tips
+
+ - **Batch inference:** Pass a list of strings: `pipe.generate(["Text one.", "Text two."])`. Watch VRAM; reduce batch size or use 4-bit if OOM.
+ - **Longer generation:** Increase `max_new_tokens` in `generate_kwargs` (e.g. `pipe.generate(text, generate_kwargs={"max_new_tokens": 512})`). CSM/Bark/MusicGen use this; the exact cap is model-specific.
+ - **Emotion / style:**
+   - **CSM:** Use the chat template with reference audio for voice cloning and tone.
+   - **Bark:** Use semantic tokens / history if you customize inputs.
+   - **Dia (nari-labs/Dia-1.6B-0626):** Speaker tags `[S1]`/`[S2]` and non-verbal cues like `(laughs)`.
+ - **Stability:** Prefer `torch.bfloat16` on Ampere+; use `use_4bit=True` to stay under 8–10 GB VRAM.
+ - **CPU fallback:** Omit `device_map="auto"` and set `device=-1` in the pipeline if you have no GPU; generation will be slow.
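The `max_new_tokens` tip maps to duration via the model's audio-token rate. A sketch of the heuristic app.py uses for MusicGen (~50 audio tokens per second of output; other models use different frame rates, so check each model's docs):

```python
TOKENS_PER_SECOND = 50  # MusicGen-style codec rate; an assumption, not universal

def max_new_tokens_for(duration_s: float, tokens_per_second: int = TOKENS_PER_SECOND) -> int:
    """Rough token budget for a target audio duration."""
    return int(duration_s * tokens_per_second)

print(max_new_tokens_for(5.0))  # → 250
```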
+
+ ---
+
+ ## Error handling and memory
+
+ - **OOM:** Enable 4-bit (`use_4bit=True`) or switch to a smaller preset (e.g. `bark-small`, `speecht5`).
+ - **Flash Attention 2:** If loading fails with an attention-backend error, the code falls back to the default attention implementation for that model.
+ - **VRAM check:** `generate_with_profile()` returns `vram_peak_mb` when CUDA is available; use it to size batches and models.
+
+ ---
+
+ ## Presets
+
+ | Preset | Model ID | Description |
+ |----------------|-------------------------|--------------------------------------|
+ | csm-1b | sesame/csm-1b | 1B conversational TTS, voice cloning |
+ | bark-small | suno/bark-small | Multilingual, non-verbal |
+ | speecht5 | microsoft/speecht5_tts | Multi-speaker (x-vector) |
+ | musicgen-small | facebook/musicgen-small | Music / SFX |
+
+ You can pass any `model_id` supported by the `transformers` text-to-audio pipeline via `build_pipeline(model_id="org/model-name")`.
+
+ ---
+
+ ## Code layout
+
+ - `src/text_to_audio/pipeline.py` — Pipeline wrapper, GPU opts, streaming chunks, memory profiling.
+ - `src/text_to_audio/__init__.py` — Exposes `build_pipeline`, `TextToAudioPipeline`, `list_presets`.
+ - `demo.py` — Gradio UI and CLI; `python demo.py` (Gradio) or `python demo.py --cli --text "..." --output out.wav`.
+ - `tests/test_pipeline.py` — Unit tests (no model download); run after `pip install -r requirements.txt`.
app.py ADDED
@@ -0,0 +1,107 @@
+ """
+ MusicSampler HF Space — Modular AudioGenerator for DAW-INVADER.
+ Exposes a Gradio UI and a FastAPI endpoint for remote Vercel integration.
+ """
+
+ from __future__ import annotations
+
+ import os
+ import torch
+ from fastapi import FastAPI, BackgroundTasks
+ from fastapi.responses import FileResponse
+ from pydantic import BaseModel
+ import gradio as gr
+ import soundfile as sf
+ import numpy as np
+ import uuid
+
+ from src.text_to_audio import build_pipeline
+
+ # Initialize pipeline (defaulting to musicgen-small for MusicSampler)
+ MODEL_PRESET = os.getenv("MODEL_PRESET", "musicgen-small")
+ USE_4BIT = os.getenv("USE_4BIT", "False").lower() == "true"
+
+ print(f"Loading {MODEL_PRESET} (4-bit={USE_4BIT})...")
+ pipe = build_pipeline(preset=MODEL_PRESET, use_4bit=USE_4BIT)
+
+ # FastAPI setup
+ app = FastAPI(title="MusicSampler API")
+
+ class GenRequest(BaseModel):
+     prompt: str
+     duration: float = 5.0
+     model: str = MODEL_PRESET
+
+ @app.post("/generate")
+ async def api_generate(req: GenRequest, background_tasks: BackgroundTasks):
+     """API endpoint for DAW-INVADER / Vercel integration."""
+     filename = f"gen_{uuid.uuid4()}.wav"
+     output_path = os.path.join("/tmp", filename)
+
+     # MusicGen supports 'max_new_tokens' via generate_kwargs;
+     # 5 seconds ≈ 250 tokens for MusicGen-small (~50 tokens/s)
+     tokens = int(req.duration * 50)
+
+     out = pipe.generate(
+         req.prompt,
+         generate_kwargs={"max_new_tokens": tokens},
+     )
+
+     single = out if isinstance(out, dict) else out[0]
+     audio = single["audio"]
+     sr = single["sampling_rate"]
+
+     if hasattr(audio, "numpy"):
+         arr = audio.detach().cpu().numpy()  # move off GPU before converting
+     else:
+         arr = np.asarray(audio)
+
+     sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
+
+     # BackgroundTasks run after the response is sent, so the file is cleaned up safely
+     background_tasks.add_task(os.remove, output_path)
+
+     return FileResponse(output_path, media_type="audio/wav", filename=filename)
+
+ # Gradio interface
+ def gradio_gen(prompt: str, duration: float):
+     tokens = int(duration * 50)
+     out, profile = pipe.generate_with_profile(
+         prompt,
+         generate_kwargs={"max_new_tokens": tokens},
+     )
+     single = out if isinstance(out, dict) else out[0]
+     audio = single["audio"]
+     sr = single["sampling_rate"]
+
+     if hasattr(audio, "numpy"):
+         arr = audio.detach().cpu().numpy()  # move off GPU before converting
+     else:
+         arr = np.asarray(audio)
+
+     path = f"/tmp/gradio_{uuid.uuid4()}.wav"
+     sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
+     return path, f"Generated in {profile.get('time_s', 0):.2f}s (RTF: {profile.get('rtf', 0):.2f})"
+
+ with gr.Blocks(title="MusicSampler", theme=gr.themes.Monochrome()) as ui:
+     gr.Markdown("# 🎹 MusicSampler")
+     gr.Markdown("Modular AudioGenerator for **DAW-INVADER**. Use the UI or POST to `/generate`.")
+
+     with gr.Row():
+         with gr.Column():
+             prompt = gr.Textbox(label="Musical Prompt", placeholder="Lo-fi hip hop beat with smooth rhodes piano...", lines=3)
+             duration = gr.Slider(minimum=1, maximum=30, value=5, step=1, label="Duration (seconds)")
+             btn = gr.Button("Sample", variant="primary")
+         with gr.Column():
+             audio_out = gr.Audio(label="Output Sample", type="filepath")
+             stats = gr.Label(label="Performance")
+
+     btn.click(gradio_gen, inputs=[prompt, duration], outputs=[audio_out, stats])
+
+ # Mount Gradio into FastAPI
+ app = gr.mount_gradio_app(app, ui, path="/")
+
+ if __name__ == "__main__":
+     import uvicorn
+     uvicorn.run(app, host="0.0.0.0", port=7860)
demo.py ADDED
@@ -0,0 +1,130 @@
+ """
+ Gradio and CLI entrypoint for the text-to-audio pipeline.
+ Run: python demo.py [--cli] [--model PRESET] [--quantize] [--text "Hello world"]
+ """
+
+ from __future__ import annotations
+
+ import argparse
+ import sys
+
+ import numpy as np
+
+
+ def run_gradio(
+     preset: str = "csm-1b",
+     use_4bit: bool = False,
+     use_8bit: bool = False,
+ ) -> None:
+     import gradio as gr
+     import soundfile as sf
+     from src.text_to_audio import build_pipeline, list_presets
+
+     presets = list_presets()  # available preset names (reserved for future UI use)
+     pipe = build_pipeline(
+         preset=preset,
+         use_4bit=use_4bit,
+         use_8bit=use_8bit,
+     )
+
+     def generate_audio(text: str, progress=gr.Progress()) -> str | None:
+         if not text or not text.strip():
+             return None
+         progress(0.2, desc="Generating...")
+         try:
+             out, profile = pipe.generate_with_profile(text.strip())
+             single = out if isinstance(out, dict) else out[0]
+             audio = single["audio"]
+             sr = single["sampling_rate"]
+             if hasattr(audio, "numpy"):
+                 arr = audio.detach().cpu().numpy()  # move off GPU before converting
+             else:
+                 arr = np.asarray(audio)
+             path = "/tmp/tta_output.wav"
+             sf.write(path, arr.T if arr.ndim == 2 else arr, sr)
+             progress(1.0, desc=f"Done — {profile.get('time_s', 0):.2f}s, RTF={profile.get('rtf', 0):.2f}")
+             return path
+         except Exception as e:
+             raise gr.Error(str(e)) from e
+
+     with gr.Blocks(title="TransformerPrime TTA", theme=gr.themes.Soft()) as app:
+         gr.Markdown("# Text-to-Audio (HF pipeline, GPU-optimized)")
+         with gr.Row():
+             text_in = gr.Textbox(
+                 label="Text",
+                 placeholder="Enter text to synthesize (e.g. Hello, this is a test.)",
+                 lines=3,
+             )
+         with gr.Row():
+             gen_btn = gr.Button("Generate", variant="primary")
+         with gr.Row():
+             audio_out = gr.Audio(label="Output", type="filepath")
+             status = gr.Markdown("")
+         gen_btn.click(
+             fn=generate_audio,
+             inputs=[text_in],
+             outputs=[audio_out],
+         ).then(
+             fn=lambda: "Ready.",
+             outputs=[status],
+         )
+         gr.Markdown("### Prompt ideas\n- **Speech:** \"Welcome to the demo. This model runs on GPU with low latency.\"\n- **Expressive:** Use punctuation and short sentences for best quality.\n- **Music (MusicGen):** Switch preset to musicgen-small and try: \"Upbeat electronic dance music with a strong bass line.\"")
+
+     app.launch(server_name="0.0.0.0", server_port=7860)
+
+
+ def run_cli(
+     text: str,
+     output_path: str,
+     preset: str = "csm-1b",
+     use_4bit: bool = False,
+     use_8bit: bool = False,
+     profile: bool = True,
+ ) -> int:
+     from src.text_to_audio import build_pipeline
+     import soundfile as sf
+
+     pipe = build_pipeline(preset=preset, use_4bit=use_4bit, use_8bit=use_8bit)
+     if profile:
+         out, prof = pipe.generate_with_profile(text)
+         print(f"Time: {prof.get('time_s', 0):.2f}s | RTF: {prof.get('rtf', 0):.2f} | VRAM peak: {prof.get('vram_peak_mb', 0):.0f} MB")
+     else:
+         out = pipe.generate(text)
+     single = out if isinstance(out, dict) else out[0]
+     audio = single["audio"]
+     sr = single["sampling_rate"]
+     if hasattr(audio, "numpy"):
+         arr = audio.detach().cpu().numpy()  # move off GPU before converting
+     else:
+         arr = np.asarray(audio)
+     sf.write(output_path, arr.T if arr.ndim == 2 else arr, sr)
+     print(f"Wrote {output_path} ({sr} Hz)")
+     return 0
+
+
+ def main() -> int:
+     parser = argparse.ArgumentParser(description="TransformerPrime text-to-audio demo")
+     parser.add_argument("--cli", action="store_true", help="Use CLI instead of Gradio")
+     parser.add_argument("--model", default="csm-1b", choices=["csm-1b", "bark-small", "speecht5", "musicgen-small"], help="Model preset")
+     parser.add_argument("--quantize", action="store_true", help="Load in 4-bit (low VRAM)")
+     parser.add_argument("--text", default="", help="Input text (CLI mode)")
+     parser.add_argument("--output", "-o", default="output.wav", help="Output WAV path (CLI)")
+     parser.add_argument("--no-profile", action="store_true", help="Disable timing/VRAM print")
+     args = parser.parse_args()
+
+     if args.cli:
+         text = args.text or "Hello from TransformerPrime. This is a GPU-accelerated text-to-audio pipeline."
+         return run_cli(
+             text=text,
+             output_path=args.output,
+             preset=args.model,
+             use_4bit=args.quantize,
+             use_8bit=False,
+             profile=not args.no_profile,
+         )
+     run_gradio(preset=args.model, use_4bit=args.quantize, use_8bit=False)
+     return 0
+
+
+ if __name__ == "__main__":
+     sys.exit(main())
requirements.txt ADDED
@@ -0,0 +1,23 @@
1
+ # TransformerPrime TTA pipeline — March 2026
2
+ # GPU-accelerated text-to-audio (NVIDIA RTX 40/50, A100/H100/B200)
3
+
4
+ torch>=2.5.0
5
+ transformers>=4.57.0
6
+ accelerate>=1.2.0
7
+ sentencepiece>=0.2.0
8
+ soundfile>=0.12.0
9
+ numpy>=1.24.0
10
+ gradio>=4.0.0
11
+ fastapi>=0.110.0
12
+ uvicorn>=0.27.0
13
+ pydantic>=2.6.0
14
+
15
+ # Optional: 4-bit/8-bit quantization (reduces VRAM to <10 GB for 1B models)
16
+ bitsandbytes>=0.44.0
17
+
18
+ # Testing
19
+ pytest>=8.0.0
20
+
21
+ # Optional backends (install if you want Kokoro or Qwen3-TTS)
22
+ # kokoro>=0.9.4
23
+ # qwen-tts>=0.1.0
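Every active pin above follows the same `name>=version` shape, with comments and blank lines ignored. As a quick sanity check on that format, a stdlib-only sketch (`parse_requirements` is a hypothetical helper, not part of the repo):

```python
import re

# Matches "package>=X.Y.Z" requirement lines; comments and blanks are skipped.
REQ_RE = re.compile(r"^([A-Za-z0-9_.-]+)>=([0-9][0-9A-Za-z.]*)$")

def parse_requirements(text: str) -> dict[str, str]:
    pins: dict[str, str] = {}
    for line in text.splitlines():
        line = line.strip()
        if not line or line.startswith("#"):
            continue
        m = REQ_RE.match(line)
        if m:
            pins[m.group(1)] = m.group(2)
    return pins

sample = "# comment\ntorch>=2.5.0\ntransformers>=4.57.0\n"
print(parse_requirements(sample))  # → {'torch': '2.5.0', 'transformers': '4.57.0'}
```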
src/text_to_audio/__init__.py ADDED
@@ -0,0 +1,3 @@
+ from .pipeline import TextToAudioPipeline, build_pipeline, list_presets
+ 
+ __all__ = ["TextToAudioPipeline", "build_pipeline", "list_presets"]
src/text_to_audio/__pycache__/__init__.cpython-313.pyc ADDED
Binary file (286 Bytes)
src/text_to_audio/__pycache__/pipeline.cpython-313.pyc ADDED
Binary file (11.3 kB)
src/text_to_audio/pipeline.py ADDED
@@ -0,0 +1,234 @@
+ """
+ GPU-accelerated text-to-audio pipeline using Hugging Face transformers.
+ Optimized for NVIDIA RTX 40/50, A100/H100/B200; target <8–10 GB VRAM with quantization.
+ """
+ 
+ from __future__ import annotations
+ 
+ import time
+ from dataclasses import dataclass
+ from typing import Any, Generator, Iterator
+ 
+ import numpy as np
+ import torch
+ 
+ try:
+     from transformers import pipeline as hf_pipeline
+     from transformers import BitsAndBytesConfig
+ except ImportError as e:
+     raise ImportError("Install transformers: pip install transformers") from e
+ 
+ try:
+     from accelerate import Accelerator
+ except ImportError:
+     Accelerator = None
+ 
+ 
+ PRESETS = {
+     "csm-1b": {
+         "model_id": "sesame/csm-1b",
+         "description": "1B conversational TTS, voice cloning; Llama backbone, ~6–10 GB VRAM (4-bit).",
+     },
+     "bark-small": {
+         "model_id": "suno/bark-small",
+         "description": "Multilingual, non-verbal (laugh/cry); smaller footprint.",
+     },
+     "speecht5": {
+         "model_id": "microsoft/speecht5_tts",
+         "description": "Spectrogram + HiFi-GAN; multi-speaker with x-vector.",
+     },
+     "musicgen-small": {
+         "model_id": "facebook/musicgen-small",
+         "description": "Music/sfx; 32 kHz, generation-style.",
+     },
+ }
+ 
+ 
+ @dataclass
+ class PipelineConfig:
+     model_id: str
+     use_flash_attention_2: bool = True
+     use_4bit: bool = False
+     use_8bit: bool = False
+     torch_dtype: torch.dtype = torch.bfloat16
+     device_map: str = "auto"
+     fallback_cpu: bool = True
+ 
+ 
+ def _infer_device() -> torch.device | int:
+     if Accelerator is not None:
+         try:
+             acc = Accelerator()
+             if str(acc.device).startswith("cuda"):
+                 return acc.device
+         except Exception:
+             pass
+     if torch.cuda.is_available():
+         return 0
+     return -1
+ 
+ 
+ def _get_model_kwargs(config: PipelineConfig) -> dict[str, Any]:
+     kwargs: dict[str, Any] = {
+         "device_map": config.device_map,
+         "torch_dtype": config.torch_dtype,
+     }
+     if config.use_4bit or config.use_8bit:
+         try:
+             kwargs["quantization_config"] = BitsAndBytesConfig(
+                 load_in_4bit=config.use_4bit,
+                 load_in_8bit=config.use_8bit,
+                 bnb_4bit_compute_dtype=config.torch_dtype,
+                 bnb_4bit_quant_type="nf4",
+             )
+         except Exception:
+             kwargs.pop("quantization_config", None)
+     if config.use_flash_attention_2 and not (config.use_4bit or config.use_8bit):
+         kwargs["attn_implementation"] = "flash_attention_2"
+     return kwargs
+ 
+ 
+ def build_pipeline(
+     model_id: str | None = None,
+     preset: str | None = "csm-1b",
+     *,
+     use_flash_attention_2: bool = True,
+     use_4bit: bool = False,
+     use_8bit: bool = False,
+     torch_dtype: torch.dtype = torch.bfloat16,
+     device_map: str = "auto",
+ ) -> TextToAudioPipeline:
+     model_id = model_id or (PRESETS.get(preset or "", {}).get("model_id") or "sesame/csm-1b")
+     config = PipelineConfig(
+         model_id=model_id,
+         use_flash_attention_2=use_flash_attention_2,
+         use_4bit=use_4bit,
+         use_8bit=use_8bit,
+         torch_dtype=torch_dtype,
+         device_map=device_map,
+     )
+     return TextToAudioPipeline(config)
+ 
+ 
+ def list_presets() -> dict[str, dict[str, str]]:
+     return dict(PRESETS)
+ 
+ 
+ class TextToAudioPipeline:
+     """
+     Wrapper around transformers text-to-audio pipeline with GPU opts,
+     streaming-style chunked output, and memory profiling.
+     """
+ 
+     def __init__(self, config: PipelineConfig) -> None:
+         self.config = config
+         self._pipe: Any = None
+         self._device = _infer_device()
+ 
+     def _ensure_loaded(self) -> None:
+         if self._pipe is not None:
+             return
+         model_kwargs = _get_model_kwargs(self.config)
+         try:
+             self._pipe = hf_pipeline(
+                 task="text-to-audio",
+                 model=self.config.model_id,
+                 model_kwargs=model_kwargs,
+                 device=self._device if self.config.device_map != "auto" else None,
+             )
+         except (TypeError, ValueError) as e:
+             if "attn_implementation" in str(e) or "flash_attention_2" in str(e).lower():
+                 # Retry without FlashAttention-2 if the model does not support it.
+                 model_kwargs.pop("attn_implementation", None)
+                 self._pipe = hf_pipeline(
+                     task="text-to-audio",
+                     model=self.config.model_id,
+                     model_kwargs=model_kwargs,
+                     device=self._device if self.config.device_map != "auto" else None,
+                 )
+             else:
+                 raise
+ 
+     def generate(
+         self,
+         text: str | list[str],
+         forward_params: dict[str, Any] | None = None,
+         generate_kwargs: dict[str, Any] | None = None,
+     ) -> dict[str, Any] | list[dict[str, Any]]:
+         self._ensure_loaded()
+         forward_params = forward_params or {}
+         generate_kwargs = generate_kwargs or {}
+         if generate_kwargs:
+             forward_params["generate_kwargs"] = generate_kwargs
+         return self._pipe(text, **forward_params)
+ 
+     def generate_with_profile(
+         self,
+         text: str,
+         forward_params: dict[str, Any] | None = None,
+         generate_kwargs: dict[str, Any] | None = None,
+     ) -> tuple[dict[str, Any], dict[str, float]]:
+         self._ensure_loaded()
+         if torch.cuda.is_available():
+             torch.cuda.reset_peak_memory_stats()
+             torch.cuda.synchronize()
+         t0 = time.perf_counter()
+         result = self.generate(text, forward_params=forward_params, generate_kwargs=generate_kwargs)
+         if torch.cuda.is_available():
+             torch.cuda.synchronize()
+         elapsed = time.perf_counter() - t0
+         profile: dict[str, float] = {"time_s": elapsed}
+         if torch.cuda.is_available():
+             profile["vram_peak_mb"] = torch.cuda.max_memory_allocated() / (1024 * 1024)
+         single = result if isinstance(result, dict) else result[0]
+         if isinstance(single, dict) and "audio" in single:
+             sr = single.get("sampling_rate", 24000)
+             audio = single["audio"]
+             # torch.Tensor.size is a method, not an int, so convert to ndarray first
+             arr = audio.numpy() if hasattr(audio, "numpy") else np.asarray(audio)
+             duration_s = arr.size / max(sr, 1)
+             if duration_s > 0:
+                 profile["rtf"] = elapsed / duration_s
+                 profile["duration_s"] = duration_s
+         return result, profile
+ 
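`generate_with_profile` reports RTF (real-time factor) as wall-clock seconds divided by seconds of audio produced, so values below 1.0 mean faster than real time. A tiny standalone illustration of the same formula (`rtf` here is a throwaway helper, not part of the class):

```python
def rtf(elapsed_s: float, num_samples: int, sr: int) -> float:
    # Real-time factor: generation wall-clock time / duration of audio produced.
    duration_s = num_samples / sr
    return elapsed_s / duration_s

# 3 s of 24 kHz audio generated in 1.2 s of wall-clock time → RTF 0.4
print(round(rtf(1.2, 72000, 24000), 2))  # → 0.4
```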
+     def stream_chunks(
+         self,
+         text: str,
+         chunk_duration_s: float = 0.5,
+         forward_params: dict[str, Any] | None = None,
+         generate_kwargs: dict[str, Any] | None = None,
+     ) -> Generator[tuple[np.ndarray, int], None, None]:
+         self._ensure_loaded()
+         out, _ = self.generate_with_profile(text, forward_params=forward_params, generate_kwargs=generate_kwargs)
+         single = out if isinstance(out, dict) else out[0]
+         audio = single["audio"]
+         sr = single["sampling_rate"]
+         if hasattr(audio, "numpy"):
+             arr = audio.numpy()
+         else:
+             arr = np.asarray(audio)
+         if arr.ndim == 1:
+             arr = arr[np.newaxis, :]
+         samples_per_chunk = int(chunk_duration_s * sr)
+         for start in range(0, arr.shape[-1], samples_per_chunk):
+             chunk = arr[..., start : start + samples_per_chunk]
+             if chunk.size == 0:
+                 break
+             yield np.squeeze(chunk), sr
+ 
+     @property
+     def sampling_rate(self) -> int:
+         self._ensure_loaded()
+         return getattr(self._pipe, "sampling_rate", 24000)
+ 
+ 
+ def stream_audio_to_file(
+     chunk_iter: Iterator[tuple[np.ndarray, int]],
+     path: str,
+     subtype: str = "PCM_16",
+ ) -> None:
+     import soundfile as sf
+ 
+     first_chunk, sr = next(chunk_iter)
+     channels = 1 if first_chunk.ndim == 1 else first_chunk.shape[0]
+     with sf.SoundFile(path, "w", samplerate=sr, channels=channels, subtype=subtype) as f:
+         # soundfile expects (frames, channels); multi-channel chunks arrive as (channels, frames).
+         f.write(first_chunk if first_chunk.ndim == 1 else first_chunk.T)
+         for chunk, _ in chunk_iter:
+             f.write(chunk if chunk.ndim == 1 else chunk.T)
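The slicing in `stream_chunks` does not depend on the model at all; the same arithmetic over synthetic audio, as a self-contained sketch (`iter_chunks` is an illustrative copy, not imported from the package):

```python
import numpy as np

def iter_chunks(arr: np.ndarray, sr: int, chunk_duration_s: float = 0.5):
    # Same slicing as stream_chunks: fixed-size windows along the sample
    # axis, with a possibly shorter final chunk.
    if arr.ndim == 1:
        arr = arr[np.newaxis, :]
    samples_per_chunk = int(chunk_duration_s * sr)
    for start in range(0, arr.shape[-1], samples_per_chunk):
        chunk = arr[..., start : start + samples_per_chunk]
        if chunk.size == 0:
            break
        yield np.squeeze(chunk), sr

# 1.25 s of 24 kHz audio → two full 0.5 s chunks plus a 0.25 s remainder.
audio = np.zeros(30000, dtype=np.float32)
sizes = [c.shape[-1] for c, _ in iter_chunks(audio, sr=24000)]
print(sizes)  # → [12000, 12000, 6000]
```

Stereo input of shape (2, samples) flows through the same loop; `np.squeeze` only collapses the channel axis for mono.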
tests/test_pipeline.py ADDED
@@ -0,0 +1,39 @@
+ """Unit tests for text_to_audio pipeline (no GPU or model download required)."""
+ 
+ from __future__ import annotations
+ 
+ from src.text_to_audio import list_presets, build_pipeline, TextToAudioPipeline
+ from src.text_to_audio.pipeline import PipelineConfig, _get_model_kwargs
+ 
+ 
+ def test_list_presets() -> None:
+     presets = list_presets()
+     assert "csm-1b" in presets
+     assert "bark-small" in presets
+     assert presets["csm-1b"]["model_id"] == "sesame/csm-1b"
+ 
+ 
+ def test_build_pipeline_returns_wrapper() -> None:
+     pipe = build_pipeline(preset="csm-1b")
+     assert isinstance(pipe, TextToAudioPipeline)
+     assert pipe.config.model_id == "sesame/csm-1b"
+ 
+ 
+ def test_build_pipeline_custom_model_id() -> None:
+     pipe = build_pipeline(model_id="suno/bark-small")
+     assert pipe.config.model_id == "suno/bark-small"
+ 
+ 
+ def test_config_model_kwargs_4bit() -> None:
+     config = PipelineConfig(model_id="test", use_4bit=True, use_flash_attention_2=False)
+     kwargs = _get_model_kwargs(config)
+     assert "quantization_config" in kwargs
+     assert kwargs["torch_dtype"] is not None
+ 
+ 
+ def test_config_model_kwargs_flash_attn() -> None:
+     config = PipelineConfig(model_id="test", use_4bit=False, use_flash_attention_2=True)
+     kwargs = _get_model_kwargs(config)
+     assert kwargs.get("attn_implementation") == "flash_attention_2"
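`stream_audio_to_file` requires soundfile. For contrast, the same incremental-write pattern can be sketched with only the stdlib `wave` module (16-bit mono; `write_chunks_wav` is a hypothetical helper, not part of the package):

```python
import wave
import numpy as np

def write_chunks_wav(chunks, path: str, sr: int) -> int:
    """Append float32 chunks in [-1, 1] to a 16-bit mono WAV; returns frames written."""
    frames = 0
    with wave.open(path, "wb") as wf:
        wf.setnchannels(1)
        wf.setsampwidth(2)  # 16-bit PCM
        wf.setframerate(sr)
        for chunk in chunks:
            # Clip then scale float samples to int16, the same role PCM_16 plays in soundfile.
            pcm = (np.clip(chunk, -1.0, 1.0) * 32767).astype(np.int16)
            wf.writeframes(pcm.tobytes())
            frames += pcm.shape[-1]
    return frames

chunks = [np.zeros(12000, dtype=np.float32), np.zeros(6000, dtype=np.float32)]
print(write_chunks_wav(chunks, "demo_stream.wav", sr=24000))  # → 18000
```

Unlike `sf.SoundFile`, the `wave` module is mono/PCM-only here, but it needs no extra dependency, which can be handy in constrained test environments.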