# DAW_Sampler_Invader / tests/test_pipeline.py
# Keith
# Initial commit for HF Space
# e3f3734
"""Unit tests for text_to_audio pipeline (no GPU or model download required)."""
from __future__ import annotations
import pytest
from src.text_to_audio import list_presets, build_pipeline, TextToAudioPipeline
from src.text_to_audio.pipeline import PipelineConfig, _get_model_kwargs, PRESETS
def test_list_presets() -> None:
    """Check that the bundled presets are listed and csm-1b maps to its model id."""
    available = list_presets()
    # Both known preset names must be present, checked in a fixed order.
    for expected_name in ("csm-1b", "bark-small"):
        assert expected_name in available
    csm_entry = available["csm-1b"]
    assert csm_entry["model_id"] == "sesame/csm-1b"
def test_build_pipeline_returns_wrapper() -> None:
    """Check that build_pipeline(preset=...) returns a TextToAudioPipeline for that preset."""
    wrapper = build_pipeline(preset="csm-1b")
    assert isinstance(wrapper, TextToAudioPipeline)
    configured_id = wrapper.config.model_id
    assert configured_id == "sesame/csm-1b"
def test_build_pipeline_custom_model_id() -> None:
    """Check that an explicit model_id is carried into the pipeline config unchanged."""
    requested = "suno/bark-small"
    wrapper = build_pipeline(model_id=requested)
    assert wrapper.config.model_id == requested
def test_config_model_kwargs_4bit() -> None:
    """Check that use_4bit=True yields a quantization config and a non-None torch_dtype.

    Key presence is asserted explicitly so that a missing entry is reported
    as an assertion failure listing the returned keys, rather than as a bare
    KeyError from the subscript below.
    """
    config = PipelineConfig(model_id="test", use_4bit=True, use_flash_attention_2=False)
    kwargs = _get_model_kwargs(config)
    assert "quantization_config" in kwargs, (
        f"expected quantization_config for 4-bit config; got keys {sorted(kwargs)}"
    )
    assert "torch_dtype" in kwargs, (
        f"expected torch_dtype for 4-bit config; got keys {sorted(kwargs)}"
    )
    assert kwargs["torch_dtype"] is not None
def test_config_model_kwargs_flash_attn() -> None:
    """Check that enabling flash attention 2 requests attn_implementation='flash_attention_2'."""
    cfg = PipelineConfig(
        model_id="test",
        use_4bit=False,
        use_flash_attention_2=True,
    )
    attn = _get_model_kwargs(cfg).get("attn_implementation")
    assert attn == "flash_attention_2"