# diagnosis/ai_engine/model_loader.py
"""Singleton pattern for model loading

This loader provides a clean interface for getting detector instances.
Uses singleton pattern to ensure models are loaded only once.

Supports both:
- Legacy AdvancedStutterDetector (Whisper-based)
- New SpeechPathologyClassifier (Wav2Vec2-XLSR-53 based)
- New InferencePipeline (for real-time and batch processing)
"""

import logging
import sys
from pathlib import Path
from typing import Optional

logger = logging.getLogger(__name__)

# Add project root to path for imports
_project_root = Path(__file__).parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))

# Singleton instances.
# NOTE(review): access is not synchronized — two threads racing on the first
# call could each load a model. Serialize the initial call if that matters.
_detector_instance = None
_speech_pathology_model_instance = None
_inference_pipeline_instance = None


def get_stutter_detector():
    """
    Get or create singleton AdvancedStutterDetector instance (legacy).

    This ensures models are loaded only once and reused across requests.

    Returns:
        AdvancedStutterDetector: The singleton detector instance

    Raises:
        ImportError: If the detector class cannot be imported
    """
    global _detector_instance

    if _detector_instance is None:
        try:
            # Lazy import keeps this module importable even when the legacy
            # detector (and its heavy Whisper dependencies) are absent.
            from .detect_stuttering import AdvancedStutterDetector

            logger.info("🔄 Initializing detector instance (first call)...")
            _detector_instance = AdvancedStutterDetector()
            logger.info("✅ Detector instance created successfully")
        except ImportError as e:
            # Lazy %-style args: message is only formatted if the record is emitted.
            logger.error("❌ Failed to import AdvancedStutterDetector: %s", e)
            raise ImportError("No StutterDetector implementation available in detect_stuttering.py") from e
        except Exception as e:
            logger.error("❌ Failed to create detector instance: %s", e)
            raise

    return _detector_instance


def get_speech_pathology_model(model_path: Optional[str] = None):
    """
    Get or create singleton SpeechPathologyClassifier instance.

    Uses Wav2Vec2-XLSR-53 with custom multi-task head for fluency and
    articulation.

    Args:
        model_path: Optional path to saved model checkpoint. Only honored on
            the first call; once an instance is cached, a differing path is
            ignored (a warning is logged) until the singleton is reset.

    Returns:
        SpeechPathologyClassifier: The singleton model instance

    Raises:
        ImportError: If the model class cannot be imported
        RuntimeError: If model cannot be loaded
    """
    global _speech_pathology_model_instance

    if _speech_pathology_model_instance is None:
        try:
            from models.speech_pathology_model import load_speech_pathology_model
            from config import default_model_config

            logger.info("🔄 Initializing SpeechPathologyClassifier (first call)...")
            _speech_pathology_model_instance = load_speech_pathology_model(
                model_name=default_model_config.model_name,
                classifier_hidden_dims=default_model_config.classifier_hidden_dims,
                dropout=default_model_config.dropout,
                device=default_model_config.device,
                use_fp16=default_model_config.use_fp16,
                model_path=model_path
            )
            logger.info("✅ SpeechPathologyClassifier instance created successfully")
        except ImportError as e:
            logger.error("❌ Failed to import SpeechPathologyClassifier: %s", e)
            raise ImportError("Failed to import SpeechPathologyClassifier. Check models package.") from e
        except Exception as e:
            logger.error("❌ Failed to create SpeechPathologyClassifier instance: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to load SpeechPathologyClassifier: {e}") from e
    elif model_path is not None:
        # Fix: a checkpoint path passed after the first load used to be
        # silently ignored. It is still ignored (singleton semantics), but now
        # the caller is told a reset is required to actually reload.
        logger.warning(
            "⚠️ model_path=%s ignored: model already loaded. "
            "Call reset_speech_pathology_model() first to reload.",
            model_path,
        )

    return _speech_pathology_model_instance


def get_inference_pipeline(model_path: Optional[str] = None):
    """
    Get or create singleton InferencePipeline instance.

    Provides both batch and streaming inference capabilities with phone-level
    analysis.

    Args:
        model_path: Optional path to saved model checkpoint. Only honored on
            the first call; once a pipeline is cached, a differing path is
            ignored (a warning is logged) until the singleton is reset.

    Returns:
        InferencePipeline: The singleton pipeline instance

    Raises:
        ImportError: If the pipeline class cannot be imported
        RuntimeError: If pipeline cannot be initialized
    """
    global _inference_pipeline_instance

    if _inference_pipeline_instance is None:
        try:
            from inference.inference_pipeline import InferencePipeline
            from config import default_audio_config, default_model_config, default_inference_config

            logger.info("🔄 Initializing InferencePipeline (first call)...")

            # Load model if path provided, otherwise let pipeline create it
            model = None
            if model_path:
                model = get_speech_pathology_model(model_path=model_path)

            _inference_pipeline_instance = InferencePipeline(
                model=model,
                audio_config=default_audio_config,
                model_config=default_model_config,
                inference_config=default_inference_config
            )
            logger.info("✅ InferencePipeline instance created successfully")
        except ImportError as e:
            logger.error("❌ Failed to import InferencePipeline: %s", e)
            raise ImportError("Failed to import InferencePipeline. Check inference package.") from e
        except Exception as e:
            logger.error("❌ Failed to create InferencePipeline instance: %s", e, exc_info=True)
            raise RuntimeError(f"Failed to initialize InferencePipeline: {e}") from e
    elif model_path is not None:
        # Same fix as get_speech_pathology_model: surface the ignored argument.
        logger.warning(
            "⚠️ model_path=%s ignored: pipeline already initialized. "
            "Call reset_inference_pipeline() first to reload.",
            model_path,
        )

    return _inference_pipeline_instance


def reset_detector():
    """
    Reset the singleton detector instance (useful for testing or reloading models).

    Note: This will force reloading of models on next get_stutter_detector() call.
    """
    global _detector_instance
    _detector_instance = None
    logger.info("🔄 Detector instance reset")


def reset_speech_pathology_model():
    """
    Reset the singleton SpeechPathologyClassifier instance.

    Note: This will force reloading of models on next get_speech_pathology_model() call.
    """
    global _speech_pathology_model_instance
    _speech_pathology_model_instance = None
    logger.info("🔄 SpeechPathologyClassifier instance reset")


def reset_inference_pipeline():
    """
    Reset the singleton InferencePipeline instance.

    Note: This will force reloading on next get_inference_pipeline() call.
    """
    global _inference_pipeline_instance
    _inference_pipeline_instance = None
    logger.info("🔄 InferencePipeline instance reset")


def reset_all():
    """
    Reset all singleton instances (useful for testing or complete reload).
    """
    reset_detector()
    reset_speech_pathology_model()
    reset_inference_pipeline()
    logger.info("🔄 All model instances reset")