Spaces:
Sleeping
Sleeping
| """Factory for creating STT provider instances.""" | |
| import logging | |
| from typing import Dict, Type, Optional | |
| from ..base.stt_provider_base import STTProviderBase | |
| from .whisper_provider import WhisperSTTProvider | |
| from ...domain.exceptions import SpeechRecognitionException | |
| logger = logging.getLogger(__name__) | |
class STTProviderFactory:
    """Factory for creating STT provider instances with availability checking and fallback logic.

    All methods are classmethods operating on a class-level registry, so the
    factory is used without instantiation, e.g.
    ``STTProviderFactory.create_provider("whisper")``.
    """

    # Registry of lower-cased provider name -> provider class.
    _providers: Dict[str, Type[STTProviderBase]] = {
        "whisper": WhisperSTTProvider,
    }

    # Provider names tried, in order, when the preferred provider fails.
    _fallback_order = ["whisper"]

    # Known model names (lower-case) and the provider that serves them.
    # Kept at class level so the table is built once rather than on every lookup.
    _model_to_provider: Dict[str, str] = {
        # Whisper model variants -> whisper provider
        "whisper-large": "whisper",
        "whisper-large-v1": "whisper",
        "whisper-large-v2": "whisper",
        "whisper-large-v3": "whisper",
        "whisper-medium": "whisper",
        "whisper-medium.en": "whisper",
        "whisper-small": "whisper",
        "whisper-small.en": "whisper",
        "whisper-base": "whisper",
        "whisper-base.en": "whisper",
        "whisper-tiny": "whisper",
        "whisper-tiny.en": "whisper",
        # Legacy model names
        "faster-whisper": "whisper",
        "openai-whisper": "whisper",
    }

    @classmethod
    def create_provider(cls, provider_name: str) -> STTProviderBase:
        """
        Create an STT provider instance by name.

        The name is matched case-insensitively against the registry; if it is
        not a registered provider it is looked up as a model name and mapped
        to the provider that serves it.

        Args:
            provider_name: Name of the provider (or a known model name) to create

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If the provider is unknown, is not
                available, or its construction fails
        """
        requested = provider_name.lower()

        # Debug logging
        logger.info("Attempting to create STT provider: '%s'", requested)
        logger.info("Available providers: %s", list(cls._providers.keys()))

        resolved = cls._resolve_provider_name(requested)
        if resolved is None:
            logger.error(
                "Unknown STT provider: %s. Available: %s",
                requested,
                list(cls._providers.keys()),
            )
            raise SpeechRecognitionException(f"Unknown STT provider: {requested}")
        if resolved != requested:
            logger.info("Mapped model '%s' to provider '%s'", requested, resolved)

        provider_class = cls._providers[resolved]

        # Only construction and the availability probe are guarded; the
        # "not available" error is raised outside the try so it is not
        # caught and re-wrapped into a doubled message.
        try:
            provider = provider_class()
            available = provider.is_available()
        except Exception as e:
            logger.error("Failed to create STT provider %s: %s", resolved, e)
            raise SpeechRecognitionException(
                f"Failed to create STT provider {resolved}: {str(e)}"
            ) from e

        if not available:
            logger.error("STT provider %s is not available", resolved)
            raise SpeechRecognitionException(f"STT provider {resolved} is not available")

        logger.info("Created STT provider: %s", resolved)
        return provider

    @classmethod
    def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
        """
        Create an STT provider with fallback to other available providers.

        Args:
            preferred_provider: The preferred provider name

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If no providers are available
        """
        # Try preferred provider first
        try:
            return cls.create_provider(preferred_provider)
        except SpeechRecognitionException as e:
            logger.warning("Preferred STT provider %s failed: %s", preferred_provider, e)

        # Remember both the raw name and the provider it resolves to, so a
        # model alias (e.g. "whisper-large") does not cause the same provider
        # to be instantiated a second time during fallback.
        already_tried = {preferred_provider.lower()}
        resolved = cls._resolve_provider_name(preferred_provider)
        if resolved is not None:
            already_tried.add(resolved)

        # Try fallback providers
        for provider_name in cls._fallback_order:
            if provider_name.lower() in already_tried:
                continue  # Skip the preferred provider we already tried
            try:
                logger.info("Trying fallback STT provider: %s", provider_name)
                return cls.create_provider(provider_name)
            except SpeechRecognitionException as e:
                logger.warning("Fallback STT provider %s failed: %s", provider_name, e)

        raise SpeechRecognitionException("No STT providers are available")

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available STT providers.

        Each registered provider is instantiated and probed with
        ``is_available()``; providers that raise are logged and skipped.

        Returns:
            list[str]: List of available provider names
        """
        available = []
        for provider_name, provider_class in cls._providers.items():
            try:
                if provider_class().is_available():
                    available.append(provider_name)
            except Exception as e:
                logger.info("Provider %s not available: %s", provider_name, e)
        return available

    @classmethod
    def get_provider_info(cls, provider_name: str) -> Optional[dict]:
        """
        Get information about a specific provider.

        Args:
            provider_name: Name of the provider

        Returns:
            Optional[dict]: Provider information, or None if the name is not
            registered. If the provider cannot be inspected, a dict with
            ``available`` set to False and an ``error`` message is returned.
        """
        provider_name = provider_name.lower()
        provider_class = cls._providers.get(provider_name)
        if provider_class is None:
            return None

        try:
            provider = provider_class()
            # Probe availability once and reuse the answer for every field.
            available = provider.is_available()
            return {
                "name": provider.provider_name,
                "available": available,
                "supported_languages": provider.supported_languages,
                "available_models": provider.get_available_models() if available else [],
                "default_model": provider.get_default_model() if available else None,
            }
        except Exception as e:
            logger.info("Failed to get info for provider %s: %s", provider_name, e)
            return {
                "name": provider_name,
                "available": False,
                "error": str(e),
            }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
        """
        Register a new STT provider.

        The name is stored lower-cased; registering an existing name replaces
        the previous class. The fallback order is not modified.

        Args:
            name: Name of the provider
            provider_class: The provider class
        """
        cls._providers[name.lower()] = provider_class
        logger.info("Registered STT provider: %s", name)

    @classmethod
    def _resolve_provider_name(cls, name: str) -> Optional[str]:
        """
        Resolve a provider or model name to a registered provider name.

        Args:
            name: A provider name or a model name (case-insensitive)

        Returns:
            Optional[str]: The registered provider name, or None if the name
            is neither a registered provider nor maps to one
        """
        key = name.lower()
        if key in cls._providers:
            return key
        mapped = cls._map_model_to_provider(key)
        # Guard against a mapping that points at a provider which is not
        # (or no longer) registered, instead of failing later with KeyError.
        if mapped in cls._providers:
            return mapped
        return None

    @classmethod
    def _map_model_to_provider(cls, model_name: str) -> Optional[str]:
        """
        Map a specific model name to a provider name.

        Args:
            model_name: The model name to map

        Returns:
            Optional[str]: The provider name if mapping exists, None otherwise
        """
        wanted = model_name.lower()

        # Try exact match first
        exact = cls._model_to_provider.get(wanted)
        if exact is not None:
            return exact

        # Then accept any name that *starts with* a known model name
        # (e.g. "whisper-large-v3-turbo" starts with "whisper-large").
        for model_prefix, provider in cls._model_to_provider.items():
            if wanted.startswith(model_prefix):
                logger.info(
                    "Prefix match: '%s' -> '%s' (matched '%s')",
                    model_name,
                    provider,
                    model_prefix,
                )
                return provider

        return None
| # Legacy compatibility - create an ASRFactory alias | |
class ASRFactory:
    """Legacy ASRFactory for backward compatibility."""

    @staticmethod
    def get_model(model_name: str = "whisper") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

        The requested provider is tried first; if it cannot be created, the
        factory's fallback order is used.

        Args:
            model_name: Name of the model/provider to use

        Returns:
            STTProviderBase: The provider instance

        Raises:
            SpeechRecognitionException: If neither the requested provider nor
                any fallback provider can be created
        """
        # Map legacy model names to provider names
        provider_mapping = {
            "whisper": "whisper",
            "faster-whisper": "whisper",
        }
        requested = model_name.lower()
        provider_name = provider_mapping.get(requested, requested)

        # create_provider_with_fallback already attempts the preferred
        # provider before falling back, so calling create_provider first
        # would only instantiate a failing provider twice.
        return STTProviderFactory.create_provider_with_fallback(provider_name)