File size: 2,399 Bytes
ef0d09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1cbd55c
ef0d09e
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
from abc import ABC, abstractmethod
from typing import Tuple

import numpy as np

from ..models.base import BaseTTSConfig


class TTSInterface(ABC):
    """TTS服务的抽象接口"""

    def __init__(self, config: BaseTTSConfig):
        self.config = config
        self._is_ready = False

    @abstractmethod
    def setup(self, **kwargs) -> None:
        """
        初始化TTS服务
        
        Args:
            **kwargs: 额外的初始化参数
        """
        pass

    @abstractmethod
    def warmup(self, warmup_steps: int = 1) -> None:
        """
        预热TTS引擎
        
        Args:
            warmup_steps: 预热步数
        """
        pass

    @abstractmethod
    def synthesize(self, text: str, **kwargs) -> Tuple[np.ndarray, int]:
        """
        将文本转换为语音
        
        Args:
            text: 要转换的文本
            **kwargs: 额外的合成参数
            
        Returns:
            Tuple[np.ndarray, int]: (音频数据, 采样率)
        """
        pass

    @property
    def is_ready(self) -> bool:
        """
        检查TTS服务是否准备就绪
        
        Returns:
            bool: 是否准备就绪
        """
        return self._is_ready

    @is_ready.setter
    def is_ready(self, value: bool):
        self._is_ready = value

    def get_config(self) -> BaseTTSConfig:
        """获取当前配置"""
        return self.config


class TTSFactory:
    """TTS工厂类,用于创建不同的TTS实现"""

    _registry = {}

    @classmethod
    def register(cls, provider_name: str, tts_class):
        """注册TTS提供者"""
        cls._registry[provider_name] = tts_class

    @classmethod
    def create(cls, config: BaseTTSConfig) -> TTSInterface:
        """
        根据配置创建TTS实例
        
        Args:
            config: TTS配置
            
        Returns:
            TTSInterface: TTS实例
            
        Raises:
            ValueError: 不支持的TTS提供者
        """
        provider = config.provider.value
        if provider not in cls._registry:
            raise ValueError(f"不支持的TTS提供者: {provider}")

        tts_class = cls._registry[provider]
        return tts_class(config)

    @classmethod
    def list_providers(cls):
        """列出所有已注册的TTS提供者"""
        return list(cls._registry.keys())