Spaces:
Runtime error
Runtime error
| """Unit tests for text_to_audio pipeline (no GPU or model download required).""" | |
| from __future__ import annotations | |
| import pytest | |
| from src.text_to_audio import list_presets, build_pipeline, TextToAudioPipeline | |
| from src.text_to_audio.pipeline import PipelineConfig, _get_model_kwargs, PRESETS | |
def test_list_presets() -> None:
    """Check that the preset registry exposes the expected entries.

    Both the ``csm-1b`` and ``bark-small`` presets must be listed, and the
    ``csm-1b`` entry must map to the ``sesame/csm-1b`` model id.
    """
    available = list_presets()
    for expected_name in ("csm-1b", "bark-small"):
        assert expected_name in available
    csm_entry = available["csm-1b"]
    assert csm_entry["model_id"] == "sesame/csm-1b"
def test_build_pipeline_returns_wrapper() -> None:
    """Check that building from a preset returns a configured wrapper.

    ``build_pipeline(preset="csm-1b")`` must return a ``TextToAudioPipeline``
    whose config carries the ``sesame/csm-1b`` model id.
    """
    wrapper = build_pipeline(preset="csm-1b")
    assert isinstance(wrapper, TextToAudioPipeline)
    resolved_model_id = wrapper.config.model_id
    assert resolved_model_id == "sesame/csm-1b"
def test_build_pipeline_custom_model_id() -> None:
    """Check that an explicit ``model_id`` is carried into the pipeline config."""
    requested_model_id = "suno/bark-small"
    wrapper = build_pipeline(model_id=requested_model_id)
    assert wrapper.config.model_id == requested_model_id
def test_config_model_kwargs_4bit() -> None:
    """Check the model kwargs produced when 4-bit loading is enabled.

    With ``use_4bit=True`` (and flash attention disabled) the kwargs must
    include a ``quantization_config`` entry and a non-``None`` ``torch_dtype``.
    """
    cfg = PipelineConfig(
        model_id="test",
        use_4bit=True,
        use_flash_attention_2=False,
    )
    model_kwargs = _get_model_kwargs(cfg)
    assert "quantization_config" in model_kwargs
    assert model_kwargs["torch_dtype"] is not None
def test_config_model_kwargs_flash_attn() -> None:
    """Check the model kwargs produced when flash attention 2 is enabled.

    With ``use_flash_attention_2=True`` (and 4-bit loading disabled) the
    ``attn_implementation`` kwarg must equal ``"flash_attention_2"``.
    """
    cfg = PipelineConfig(
        model_id="test",
        use_4bit=False,
        use_flash_attention_2=True,
    )
    model_kwargs = _get_model_kwargs(cfg)
    attn_impl = model_kwargs.get("attn_implementation")
    assert attn_impl == "flash_attention_2"