# Source imported from commit 79f7931 — "Feat: project structure and configuration" (author: anfastech)
# diagnosis/ai_engine/model_loader.py
"""Singleton pattern for model loading
This loader provides a clean interface for getting detector instances.
Uses singleton pattern to ensure models are loaded only once.
Supports both:
- Legacy AdvancedStutterDetector (Whisper-based)
- New SpeechPathologyClassifier (Wav2Vec2-XLSR-53 based)
- New InferencePipeline (for real-time and batch processing)
"""
import logging
import sys
from pathlib import Path
from typing import Optional
logger = logging.getLogger(__name__)
# Add project root to path for imports.
# __file__ is diagnosis/ai_engine/model_loader.py, so three .parent hops reach
# the repository root; inserting at index 0 gives it lookup priority.
_project_root = Path(__file__).parent.parent.parent
if str(_project_root) not in sys.path:
    sys.path.insert(0, str(_project_root))
# Singleton instances — None until first use; populated lazily by the
# get_*() accessors below and cleared by the reset_*() helpers.
_detector_instance = None
_speech_pathology_model_instance = None
_inference_pipeline_instance = None
def get_stutter_detector():
    """
    Get or create the singleton AdvancedStutterDetector instance (legacy).

    This ensures models are loaded only once and reused across requests.

    Returns:
        AdvancedStutterDetector: The singleton detector instance

    Raises:
        ImportError: If the detector class cannot be imported
    """
    global _detector_instance
    if _detector_instance is None:
        try:
            # Imported lazily so the heavy (Whisper-based) dependency chain is
            # only paid on first use.
            from .detect_stuttering import AdvancedStutterDetector
            logger.info("🔄 Initializing detector instance (first call)...")
            _detector_instance = AdvancedStutterDetector()
            logger.info("✅ Detector instance created successfully")
        except ImportError as e:
            logger.error(f"❌ Failed to import AdvancedStutterDetector: {e}")
            raise ImportError("No StutterDetector implementation available in detect_stuttering.py") from e
        except Exception as e:
            # exc_info=True for traceback in logs — consistent with the other loaders.
            logger.error(f"❌ Failed to create detector instance: {e}", exc_info=True)
            raise
    return _detector_instance
def get_speech_pathology_model(model_path: Optional[str] = None):
    """
    Get or create the singleton SpeechPathologyClassifier instance.

    Uses Wav2Vec2-XLSR-53 with a custom multi-task head for fluency and
    articulation scoring.

    Args:
        model_path: Optional path to a saved model checkpoint. When omitted,
            the model is built from the defaults in ``config``.

    Returns:
        SpeechPathologyClassifier: The singleton model instance

    Raises:
        ImportError: If the model class cannot be imported
        RuntimeError: If the model cannot be loaded
    """
    global _speech_pathology_model_instance
    if _speech_pathology_model_instance is None:
        try:
            # Lazy imports: avoid loading torch/transformers at module import time.
            from models.speech_pathology_model import load_speech_pathology_model
            from config import default_model_config
            logger.info("🔄 Initializing SpeechPathologyClassifier (first call)...")
            _speech_pathology_model_instance = load_speech_pathology_model(
                model_name=default_model_config.model_name,
                classifier_hidden_dims=default_model_config.classifier_hidden_dims,
                dropout=default_model_config.dropout,
                device=default_model_config.device,
                use_fp16=default_model_config.use_fp16,
                model_path=model_path
            )
            logger.info("✅ SpeechPathologyClassifier instance created successfully")
        except ImportError as e:
            logger.error(f"❌ Failed to import SpeechPathologyClassifier: {e}")
            raise ImportError("Failed to import SpeechPathologyClassifier. Check models package.") from e
        except Exception as e:
            logger.error(f"❌ Failed to create SpeechPathologyClassifier instance: {e}", exc_info=True)
            # Wrap in RuntimeError so callers have a single failure type to catch.
            raise RuntimeError(f"Failed to load SpeechPathologyClassifier: {e}") from e
    return _speech_pathology_model_instance
def get_inference_pipeline(model_path: Optional[str] = None):
    """
    Get or create the singleton InferencePipeline instance.

    Provides both batch and streaming inference capabilities with
    phone-level analysis.

    Args:
        model_path: Optional path to a saved model checkpoint. When given,
            the shared SpeechPathologyClassifier singleton is loaded from it
            and injected; otherwise the pipeline creates its own model.

    Returns:
        InferencePipeline: The singleton pipeline instance

    Raises:
        ImportError: If the pipeline class cannot be imported
        RuntimeError: If the pipeline cannot be initialized
    """
    global _inference_pipeline_instance
    if _inference_pipeline_instance is None:
        try:
            # Lazy imports keep module import cheap until the pipeline is needed.
            from inference.inference_pipeline import InferencePipeline
            from config import default_audio_config, default_model_config, default_inference_config
            logger.info("🔄 Initializing InferencePipeline (first call)...")
            # Load model if path provided, otherwise let pipeline create it
            model = None
            if model_path:
                model = get_speech_pathology_model(model_path=model_path)
            _inference_pipeline_instance = InferencePipeline(
                model=model,
                audio_config=default_audio_config,
                model_config=default_model_config,
                inference_config=default_inference_config
            )
            logger.info("✅ InferencePipeline instance created successfully")
        except ImportError as e:
            logger.error(f"❌ Failed to import InferencePipeline: {e}")
            raise ImportError("Failed to import InferencePipeline. Check inference package.") from e
        except Exception as e:
            logger.error(f"❌ Failed to create InferencePipeline instance: {e}", exc_info=True)
            # Wrap in RuntimeError so callers have a single failure type to catch.
            raise RuntimeError(f"Failed to initialize InferencePipeline: {e}") from e
    return _inference_pipeline_instance
def reset_detector():
    """Clear the cached detector singleton.

    After this call, the next get_stutter_detector() invocation rebuilds
    the instance, forcing its models to be reloaded. Intended for tests
    and model hot-reloading.
    """
    global _detector_instance
    _detector_instance = None
    logger.info("🔄 Detector instance reset")
def reset_speech_pathology_model():
    """Clear the cached SpeechPathologyClassifier singleton.

    After this call, the next get_speech_pathology_model() invocation
    reloads the model from scratch. Intended for tests and hot-reloading.
    """
    global _speech_pathology_model_instance
    _speech_pathology_model_instance = None
    logger.info("🔄 SpeechPathologyClassifier instance reset")
def reset_inference_pipeline():
    """Clear the cached InferencePipeline singleton.

    After this call, the next get_inference_pipeline() invocation
    rebuilds the pipeline. Intended for tests and hot-reloading.
    """
    global _inference_pipeline_instance
    _inference_pipeline_instance = None
    logger.info("🔄 InferencePipeline instance reset")
def reset_all():
    """Reset every cached singleton (detector, classifier, and pipeline).

    Equivalent to calling each individual reset_*() helper in turn;
    useful for tests or a complete model reload.
    """
    for reset_fn in (reset_detector, reset_speech_pathology_model, reset_inference_pipeline):
        reset_fn()
    logger.info("🔄 All model instances reset")