Michael Hu
fix(stt): handle whisper-large model name as alias for whisper provider
b10a453
"""Factory for creating STT provider instances."""
import logging
from typing import Dict, Type, Optional
from ..base.stt_provider_base import STTProviderBase
from .whisper_provider import WhisperSTTProvider
from ...domain.exceptions import SpeechRecognitionException
# Module-level logger shared by STTProviderFactory and the legacy ASRFactory.
logger = logging.getLogger(__name__)
class STTProviderFactory:
    """Factory for creating STT provider instances with availability checking and fallback logic.

    Providers are looked up by a lowercase registry name (see ``_providers``).
    ``create_provider`` additionally accepts model names such as
    ``"whisper-large-v3"`` and resolves them to the provider that serves them
    via ``_map_model_to_provider``.
    """

    # Registry of provider name (lowercase) -> provider class.
    # NOTE: shared class-level state; register_provider() mutates it.
    _providers: Dict[str, Type[STTProviderBase]] = {
        "whisper": WhisperSTTProvider
    }

    # Providers tried, in order, by create_provider_with_fallback() after the
    # preferred provider fails.
    _fallback_order = ["whisper"]

    # Model-name aliases -> provider name. Keys must be lowercase. Used for an
    # exact match first, then for prefix matching in insertion order.
    _model_aliases: Dict[str, str] = {
        # Whisper model variants -> whisper provider
        'whisper-large': 'whisper',
        'whisper-large-v1': 'whisper',
        'whisper-large-v2': 'whisper',
        'whisper-large-v3': 'whisper',
        'whisper-medium': 'whisper',
        'whisper-medium.en': 'whisper',
        'whisper-small': 'whisper',
        'whisper-small.en': 'whisper',
        'whisper-base': 'whisper',
        'whisper-base.en': 'whisper',
        'whisper-tiny': 'whisper',
        'whisper-tiny.en': 'whisper',
        # Legacy model names
        'faster-whisper': 'whisper',
        'openai-whisper': 'whisper',
    }

    @classmethod
    def _resolve_provider_name(cls, name: str) -> Optional[str]:
        """Return the registered provider name for *name*, or None if unknown.

        *name* may be a registered provider name or a model alias understood by
        ``_map_model_to_provider``. Matching is case-insensitive. A mapped
        provider is only returned if it is actually registered.
        """
        key = name.lower()
        if key in cls._providers:
            return key
        mapped = cls._map_model_to_provider(key)
        if mapped is not None and mapped in cls._providers:
            return mapped
        return None

    @classmethod
    def create_provider(cls, provider_name: str) -> STTProviderBase:
        """
        Create an STT provider instance by name.

        Args:
            provider_name: Name of the provider to create. Case-insensitive; a
                model name such as "whisper-large" is mapped to its provider.

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If the name is unknown, the provider
                cannot be constructed, or it reports itself as unavailable
        """
        requested = provider_name.lower()
        logger.debug(f"Attempting to create STT provider: '{requested}'")
        logger.debug(f"Available providers: {list(cls._providers.keys())}")

        resolved = cls._resolve_provider_name(requested)
        if resolved is None:
            logger.error(f"Unknown STT provider: {requested}. Available: {list(cls._providers.keys())}")
            raise SpeechRecognitionException(f"Unknown STT provider: {requested}")
        if resolved != requested:
            logger.info(f"Mapped model '{requested}' to provider '{resolved}'")
        provider_name = resolved

        provider_class = cls._providers[provider_name]
        try:
            provider = provider_class()
            available = provider.is_available()
        except Exception as e:
            logger.error(f"Failed to create STT provider {provider_name}: {str(e)}")
            raise SpeechRecognitionException(f"Failed to create STT provider {provider_name}: {str(e)}") from e

        # Checked outside the try block so this error is not caught and
        # re-wrapped by the generic "Failed to create ..." handler above.
        if not available:
            logger.error(f"STT provider {provider_name} is not available")
            raise SpeechRecognitionException(f"STT provider {provider_name} is not available")

        logger.info(f"Created STT provider: {provider_name}")
        return provider

    @classmethod
    def create_provider_with_fallback(cls, preferred_provider: str) -> STTProviderBase:
        """
        Create an STT provider with fallback to other available providers.

        Args:
            preferred_provider: The preferred provider (or model alias) name

        Returns:
            STTProviderBase: The created provider instance

        Raises:
            SpeechRecognitionException: If no providers are available; the
                last fallback failure (if any) is chained as the cause
        """
        # Try preferred provider first
        try:
            return cls.create_provider(preferred_provider)
        except SpeechRecognitionException as e:
            logger.warning(f"Preferred STT provider {preferred_provider} failed: {str(e)}")

        # Record both the raw and the resolved name (e.g. "whisper-large" ->
        # "whisper") so the provider that just failed is not retried.
        preferred_key = preferred_provider.lower()
        tried = {preferred_key}
        resolved_preferred = cls._resolve_provider_name(preferred_key)
        if resolved_preferred is not None:
            tried.add(resolved_preferred)

        last_error: Optional[Exception] = None
        for provider_name in cls._fallback_order:
            key = cls._resolve_provider_name(provider_name) or provider_name.lower()
            if key in tried:
                continue
            tried.add(key)
            try:
                logger.info(f"Trying fallback STT provider: {provider_name}")
                return cls.create_provider(provider_name)
            except SpeechRecognitionException as e:
                logger.warning(f"Fallback STT provider {provider_name} failed: {str(e)}")
                last_error = e

        raise SpeechRecognitionException("No STT providers are available") from last_error

    @classmethod
    def get_available_providers(cls) -> list[str]:
        """
        Get list of available STT providers.

        Each registered provider class is instantiated and asked
        ``is_available()``; construction errors are logged and treated as
        "not available".

        Returns:
            list[str]: List of available provider names
        """
        available = []
        for provider_name, provider_class in cls._providers.items():
            try:
                provider = provider_class()
                if provider.is_available():
                    available.append(provider_name)
            except Exception as e:
                logger.info(f"Provider {provider_name} not available: {str(e)}")
        return available

    @classmethod
    def get_provider_info(cls, provider_name: str) -> Optional[dict]:
        """
        Get information about a specific provider.

        Args:
            provider_name: Registered name of the provider (case-insensitive)

        Returns:
            Optional[dict]: None if the name is not registered. Otherwise a
                dict with "name", "available", "supported_languages",
                "available_models" and "default_model"; or, if querying the
                provider raised, a dict with "name", "available" (False) and
                "error".
        """
        provider_name = provider_name.lower()
        if provider_name not in cls._providers:
            return None
        provider_class = cls._providers[provider_name]
        try:
            provider = provider_class()
            # Query availability once and reuse it for every field.
            available = provider.is_available()
            return {
                "name": provider.provider_name,
                "available": available,
                "supported_languages": provider.supported_languages,
                "available_models": provider.get_available_models() if available else [],
                "default_model": provider.get_default_model() if available else None
            }
        except Exception as e:
            logger.info(f"Failed to get info for provider {provider_name}: {str(e)}")
            return {
                "name": provider_name,
                "available": False,
                "error": str(e)
            }

    @classmethod
    def register_provider(cls, name: str, provider_class: Type[STTProviderBase]) -> None:
        """
        Register a new STT provider.

        The name is stored lowercased, replacing any existing registration.
        This does not add the provider to the fallback order.

        Args:
            name: Name of the provider
            provider_class: The provider class
        """
        cls._providers[name.lower()] = provider_class
        logger.info(f"Registered STT provider: {name}")

    @classmethod
    def _map_model_to_provider(cls, model_name: str) -> Optional[str]:
        """
        Map a specific model name to a provider name.

        Args:
            model_name: The model name to map (case-insensitive)

        Returns:
            Optional[str]: The provider name if a mapping exists, None otherwise
        """
        key = model_name.lower()

        # Try exact match first
        exact = cls._model_aliases.get(key)
        if exact is not None:
            return exact

        # Prefix match lets variant names resolve, e.g.
        # "whisper-large-v3-turbo" -> "whisper" via the "whisper-large" alias.
        for alias, provider in cls._model_aliases.items():
            if key.startswith(alias):
                logger.info(f"Prefix match: '{model_name}' -> '{provider}' (matched '{alias}')")
                return provider
        return None
# Legacy compatibility: ASRFactory keeps the old ``get_model`` entry point
# working by delegating to STTProviderFactory.
class ASRFactory:
    """Legacy ASRFactory for backward compatibility."""

    @staticmethod
    def get_model(model_name: str = "whisper") -> STTProviderBase:
        """
        Get STT provider by model name (legacy interface).

        Args:
            model_name: Name of the model/provider to use (case-insensitive)

        Returns:
            STTProviderBase: The provider instance

        Raises:
            SpeechRecognitionException: If neither the requested provider nor
                any fallback provider can be created
        """
        # Map legacy model names to provider names
        provider_mapping = {
            "whisper": "whisper",
            "faster-whisper": "whisper"
        }
        requested = model_name.lower()
        provider_name = provider_mapping.get(requested, requested)
        # create_provider_with_fallback() already tries the requested provider
        # first (logging a warning if it fails) before moving on to fallbacks.
        # Calling it directly avoids constructing the requested provider twice.
        return STTProviderFactory.create_provider_with_fallback(provider_name)