| | |
| | """Singleton pattern for model loading |
| | |
| | This loader provides a clean interface for getting detector instances. |
| | Uses singleton pattern to ensure models are loaded only once. |
| | |
| | Supports: |
| | - Legacy AdvancedStutterDetector (Whisper-based) |
| | - New SpeechPathologyClassifier (Wav2Vec2-XLSR-53 based) |
| | - New InferencePipeline (for real-time and batch processing) |
| | """ |
| | import logging |
| | import sys |
| | from pathlib import Path |
| |
|
# Logger shared by every loader function in this module.
logger = logging.getLogger(name=__name__)

# Project root: the directory two levels above the one containing this file.
# It is prepended to sys.path (only if missing), presumably so that the
# absolute imports used by the loaders (``models``, ``config``,
# ``inference``) resolve regardless of how this module was imported.
_project_root = Path(__file__).parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path[:0] = [str(_project_root)]

# Lazily populated singleton slots; ``None`` means "not created yet".
_detector_instance = _speech_pathology_model_instance = _inference_pipeline_instance = None
| |
|
| |
|
def get_stutter_detector():
    """
    Get or create the singleton AdvancedStutterDetector instance (legacy).

    The detector is constructed on the first call and cached in the
    module-level ``_detector_instance``. Every later call returns that same
    object, so the detector is constructed only once until
    ``reset_detector()`` clears the cache. If construction fails, nothing is
    cached, so the next call retries.

    Returns:
        AdvancedStutterDetector: The singleton detector instance.

    Raises:
        ImportError: If the detector class cannot be imported from the
            sibling ``detect_stuttering`` module (an ImportError raised by
            the constructor is reported the same way).
        Exception: Any other error raised by the constructor is logged with
            its traceback and re-raised unchanged.
    """
    # NOTE(review): the check-then-create below is not guarded by a lock. If
    # several threads call this before the first instance exists, more than
    # one detector may be constructed. Confirm whether callers can race here.
    global _detector_instance

    if _detector_instance is None:
        try:
            # Imported lazily so the legacy detector module is only loaded
            # when a detector is actually requested.
            from .detect_stuttering import AdvancedStutterDetector

            logger.info("🔄 Initializing detector instance (first call)...")
            _detector_instance = AdvancedStutterDetector()
            logger.info("✅ Detector instance created successfully")
        except ImportError as e:
            logger.error("❌ Failed to import AdvancedStutterDetector: %s", e)
            raise ImportError(
                "No StutterDetector implementation available in detect_stuttering.py"
            ) from e
        except Exception as e:
            logger.exception("❌ Failed to create detector instance: %s", e)
            raise

    return _detector_instance
| |
|
| |
|
def get_speech_pathology_model(model_path: str = None):
    """
    Get or create the singleton SpeechPathologyClassifier instance.

    Uses Wav2Vec2-XLSR-53 with custom multi-task head for fluency and articulation.

    The model is built on the first call via ``load_speech_pathology_model``.
    That call uses the settings in ``config.default_model_config`` (model_name,
    classifier_hidden_dims, dropout, device, use_fp16), and the result is cached
    in the module-level ``_speech_pathology_model_instance``. Later calls return
    the cached object and ignore ``model_path``; call
    ``reset_speech_pathology_model()`` first to load a different checkpoint.
    If loading fails, nothing is cached, so the next call retries.

    Args:
        model_path: Optional path to a saved model checkpoint (may be ``None``).
            Only honoured on the call that actually creates the instance.

    Returns:
        SpeechPathologyClassifier: The singleton model instance.

    Raises:
        ImportError: If the model loader or the config cannot be imported (an
            ImportError raised while loading is reported the same way).
        RuntimeError: If the model cannot be loaded for any other reason.
    """
    # NOTE(review): no lock guards the check-then-create; concurrent first
    # calls could each load a model. Confirm callers cannot race here.
    global _speech_pathology_model_instance

    if _speech_pathology_model_instance is None:
        try:
            # Imported lazily so the model code is only loaded on demand.
            from models.speech_pathology_model import load_speech_pathology_model
            from config import default_model_config

            logger.info("🔄 Initializing SpeechPathologyClassifier (first call)...")
            _speech_pathology_model_instance = load_speech_pathology_model(
                model_name=default_model_config.model_name,
                classifier_hidden_dims=default_model_config.classifier_hidden_dims,
                dropout=default_model_config.dropout,
                device=default_model_config.device,
                use_fp16=default_model_config.use_fp16,
                model_path=model_path,
            )
            logger.info("✅ SpeechPathologyClassifier instance created successfully")
        except ImportError as e:
            logger.error("❌ Failed to import SpeechPathologyClassifier: %s", e)
            raise ImportError(
                "Failed to import SpeechPathologyClassifier. Check models package."
            ) from e
        except Exception as e:
            logger.exception("❌ Failed to create SpeechPathologyClassifier instance: %s", e)
            raise RuntimeError(f"Failed to load SpeechPathologyClassifier: {e}") from e

    return _speech_pathology_model_instance
| |
|
| |
|
def get_inference_pipeline(model_path: str = None):
    """
    Get or create the singleton InferencePipeline instance.

    Provides both batch and streaming inference capabilities with phone-level analysis.

    The pipeline is built on the first call with the default audio, model and
    inference configs from ``config``, and cached in the module-level
    ``_inference_pipeline_instance``. Later calls return the cached object and
    ignore ``model_path``; call ``reset_inference_pipeline()`` first to
    rebuild. If construction fails, no pipeline is cached.

    Args:
        model_path: Optional path to a saved model checkpoint. When truthy, the
            model is obtained through ``get_speech_pathology_model`` and so
            shares that singleton. Otherwise ``model=None`` is passed to
            ``InferencePipeline``.

    Returns:
        InferencePipeline: The singleton pipeline instance.

    Raises:
        ImportError: If ``InferencePipeline`` or the default configs cannot be
            imported, or if loading the model / building the pipeline raises
            ImportError (propagated with its original message).
        RuntimeError: If the model or the pipeline cannot be initialized.
    """
    # NOTE(review): no lock guards the check-then-create; concurrent first
    # calls could each build a pipeline. Confirm callers cannot race here.
    global _inference_pipeline_instance

    if _inference_pipeline_instance is not None:
        return _inference_pipeline_instance

    # Keep this try limited to the imports so that failures from model loading
    # are not mislabelled as "Failed to import InferencePipeline".
    try:
        from inference.inference_pipeline import InferencePipeline
        from config import default_audio_config, default_model_config, default_inference_config
    except ImportError as e:
        logger.error("❌ Failed to import InferencePipeline: %s", e)
        raise ImportError("Failed to import InferencePipeline. Check inference package.") from e

    logger.info("🔄 Initializing InferencePipeline (first call)...")
    try:
        # NOTE(review): with no model_path, model=None is handed to the
        # pipeline; presumably it then obtains its own model. That could load
        # a second copy if get_speech_pathology_model() is also used. Confirm.
        model = get_speech_pathology_model(model_path=model_path) if model_path else None

        _inference_pipeline_instance = InferencePipeline(
            model=model,
            audio_config=default_audio_config,
            model_config=default_model_config,
            inference_config=default_inference_config,
        )
    except ImportError as e:
        # Keep the ImportError type callers may catch, but with its real message.
        logger.exception("❌ Failed to create InferencePipeline instance: %s", e)
        raise
    except Exception as e:
        logger.exception("❌ Failed to create InferencePipeline instance: %s", e)
        raise RuntimeError(f"Failed to initialize InferencePipeline: {e}") from e

    logger.info("✅ InferencePipeline instance created successfully")
    return _inference_pipeline_instance
| |
|
| |
|
def reset_detector():
    """
    Reset the singleton detector instance (useful for testing or reloading models).

    Only the module's cached reference is cleared. The next
    ``get_stutter_detector()`` call constructs a fresh detector, while any
    detector object a caller still holds is left untouched.
    """
    global _detector_instance
    _detector_instance = None
    logger.info("🔄 Detector instance reset")
| |
|
| |
|
def reset_speech_pathology_model():
    """
    Reset the singleton SpeechPathologyClassifier instance.

    The next ``get_speech_pathology_model()`` call reloads the model. Only the
    module's cached reference is cleared. A cached InferencePipeline that was
    built with a ``model_path`` received the old model object and may still
    reference it; reset the pipeline too (or use ``reset_all()``) to
    rebuild both.
    """
    global _speech_pathology_model_instance
    _speech_pathology_model_instance = None
    logger.info("🔄 SpeechPathologyClassifier instance reset")
| |
|
| |
|
def reset_inference_pipeline():
    """
    Reset the singleton InferencePipeline instance.

    The next ``get_inference_pipeline()`` call builds a new pipeline. The
    cached SpeechPathologyClassifier singleton is NOT cleared here, so a
    rebuilt pipeline given a ``model_path`` will reuse the existing model.
    """
    global _inference_pipeline_instance
    _inference_pipeline_instance = None
    logger.info("🔄 InferencePipeline instance reset")
| |
|
| |
|
def reset_all():
    """
    Reset all singleton instances (useful for testing or complete reload).

    Clears the legacy detector, the SpeechPathologyClassifier and the
    InferencePipeline caches. Each reset logs its own message, and a summary
    message follows.
    """
    reset_detector()
    reset_speech_pathology_model()
    reset_inference_pipeline()
    logger.info("🔄 All model instances reset")
| |
|
| |
|