# Source: zlaqa-version-c-ai-enginee / test_model_loading.py (author: anfastech)
# Commit 1c1f628: create models/speech_pathology_model.py, create
# inference/inference_pipeline.py, integrate with existing
# diagnosis/ai_engine/model_loader.py (singleton pattern)
"""
Test script for model loading and inference with dummy audio.
This script verifies that:
1. SpeechPathologyClassifier can be loaded
2. InferencePipeline can be initialized
3. Dummy audio can be processed
4. Predictions are generated correctly
"""
import logging
import sys
import numpy as np
from pathlib import Path
# Configure logging
logging.basicConfig(
level=logging.INFO,
format='%(asctime)s - %(name)s - %(levelname)s - %(message)s',
stream=sys.stdout
)
logger = logging.getLogger(__name__)
# Add project root to path
sys.path.insert(0, str(Path(__file__).parent))
def generate_dummy_audio(duration_seconds: float = 1.0, sample_rate: int = 16000) -> np.ndarray:
    """
    Generate dummy audio signal for testing.

    Produces a 440 Hz (A4) sine wave with additive Gaussian noise,
    peak-normalized so the largest absolute sample is 1.0.

    Args:
        duration_seconds: Duration of audio in seconds (non-negative)
        sample_rate: Sample rate in Hz

    Returns:
        Audio array normalized to [-1, 1]; empty array when the
        requested duration rounds to zero samples
    """
    num_samples = int(duration_seconds * sample_rate)
    if num_samples <= 0:
        # Guard: np.abs(audio).max() below would raise on an empty array.
        return np.zeros(0)
    # Generate a simple sine wave with some noise
    t = np.linspace(0, duration_seconds, num_samples)
    frequency = 440.0  # A4 note
    audio = np.sin(2 * np.pi * frequency * t)
    # Add some noise
    noise = np.random.normal(0, 0.1, num_samples)
    audio = audio + noise
    # Normalize to [-1, 1]; guard against division by zero for an
    # (unlikely) all-zero signal.
    peak = np.abs(audio).max()
    if peak > 0:
        audio = audio / peak
    return audio
def test_model_loading():
    """Attempt to load SpeechPathologyClassifier; return it, or None on failure."""
    banner = "=" * 60
    logger.info(banner)
    logger.info("TEST 1: Model Loading")
    logger.info(banner)
    try:
        from diagnosis.ai_engine.model_loader import get_speech_pathology_model
        logger.info("Loading SpeechPathologyClassifier...")
        model = get_speech_pathology_model()
        logger.info(f"βœ… Model loaded successfully")
        logger.info(f" Model type: {type(model).__name__}")
        logger.info(f" Device: {model.device}")
        logger.info(f" Articulation classes: {model.ARTICULATION_CLASSES}")
    except Exception as exc:
        # Any failure (import, download, device setup) is logged, not raised.
        logger.error(f"❌ Model loading failed: {exc}", exc_info=True)
        return None
    return model
def test_inference_pipeline_loading():
    """Attempt to initialize InferencePipeline; return it, or None on failure."""
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("TEST 2: Inference Pipeline Loading")
    logger.info(banner)
    try:
        from diagnosis.ai_engine.model_loader import get_inference_pipeline
        logger.info("Loading InferencePipeline...")
        pipeline = get_inference_pipeline()
        logger.info(f"βœ… InferencePipeline loaded successfully")
        logger.info(f" Pipeline type: {type(pipeline).__name__}")
        logger.info(f" Frame size: {pipeline.frame_size_samples} samples")
        logger.info(f" Sample rate: {pipeline.audio_config.sample_rate} Hz")
    except Exception as exc:
        # Any failure (import, model init, config) is logged, not raised.
        logger.error(f"❌ InferencePipeline loading failed: {exc}", exc_info=True)
        return None
    return pipeline
def test_dummy_audio_prediction(pipeline):
    """Run streaming and batch inference on synthetic audio; return True on success."""
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("TEST 3: Dummy Audio Prediction")
    logger.info(banner)
    if pipeline is None:
        logger.error("❌ Cannot test prediction: pipeline is None")
        return False
    try:
        # One second of synthetic audio at the expected 16 kHz rate.
        logger.info("Generating dummy audio (1 second)...")
        audio = generate_dummy_audio(duration_seconds=1.0, sample_rate=16000)
        logger.info(f" Audio shape: {audio.shape}")
        logger.info(f" Audio range: [{audio.min():.3f}, {audio.max():.3f}]")

        # Streaming path: feed exactly one frame-sized chunk.
        logger.info("\nTesting streaming prediction...")
        first_chunk = audio[:pipeline.frame_size_samples]
        streamed = pipeline.predict_streaming(first_chunk, frame_index=0, timestamp_ms=0.0)
        logger.info(f"βœ… Streaming prediction successful")
        logger.info(f" Fluency score: {streamed.fluency_score:.4f}")
        logger.info(f" Articulation class: {streamed.articulation_class} "
                    f"({streamed.articulation_class_name})")
        logger.info(f" Articulation probs: {[f'{p:.4f}' for p in streamed.articulation_probs]}")
        logger.info(f" Confidence: {streamed.confidence:.4f}")

        # Batch path: process the whole clip at once with timestamps.
        logger.info("\nTesting batch prediction...")
        batched = pipeline.predict_batch(audio, return_timestamps=True)
        logger.info(f"βœ… Batch prediction successful")
        logger.info(f" Processed {len(batched.articulation_scores)} frames")
        logger.info(f" Fluency metrics:")
        for metric_name, metric_value in batched.fluency_metrics.items():
            logger.info(f" {metric_name}: {metric_value:.4f}")
        logger.info(f" Overall confidence: {batched.confidence:.4f}")
        logger.info(f" Frame duration: {batched.frame_duration_ms} ms")
        return True
    except Exception as exc:
        logger.error(f"❌ Dummy audio prediction failed: {exc}", exc_info=True)
        return False
def test_model_direct_prediction(model):
    """Call the classifier's predict() directly (no pipeline); return True on success."""
    banner = "=" * 60
    logger.info("\n" + banner)
    logger.info("TEST 4: Direct Model Prediction")
    logger.info(banner)
    if model is None:
        logger.error("❌ Cannot test direct prediction: model is None")
        return False
    try:
        # Half a second of synthetic audio, converted to a float32 tensor.
        logger.info("Generating dummy audio...")
        waveform = generate_dummy_audio(duration_seconds=0.5, sample_rate=16000)
        import torch
        audio_tensor = torch.from_numpy(waveform).float()
        logger.info("Running direct model prediction...")
        prediction = model.predict(audio_tensor, sample_rate=16000, return_dict=True)
        logger.info(f"βœ… Direct model prediction successful")
        logger.info(f" Fluency score: {prediction['fluency_score']:.4f}")
        logger.info(f" Articulation class: {prediction['articulation_class']} "
                    f"({prediction['articulation_class_name']})")
        logger.info(f" Confidence: {prediction['confidence']:.4f}")
        return True
    except Exception as exc:
        logger.error(f"❌ Direct model prediction failed: {exc}", exc_info=True)
        return False
def main():
    """Run all tests and return a process exit code (0 = all passed, 1 otherwise)."""
    logger.info("πŸš€ Starting model loading tests...\n")
    results = []
    # Test 1: Model loading
    model = test_model_loading()
    results.append(("Model Loading", model is not None))
    # Test 2: Inference pipeline loading
    pipeline = test_inference_pipeline_loading()
    results.append(("Inference Pipeline Loading", pipeline is not None))
    # Test 3: Dummy audio prediction (skipped entirely if the pipeline failed to load)
    if pipeline:
        results.append(("Dummy Audio Prediction", test_dummy_audio_prediction(pipeline)))
    # Test 4: Direct model prediction (skipped entirely if the model failed to load)
    if model:
        # BUG FIX: previously called test_direct_prediction, which is not
        # defined anywhere in this file and raised NameError whenever the
        # model loaded successfully. The defined function is
        # test_model_direct_prediction.
        results.append(("Direct Model Prediction", test_model_direct_prediction(model)))
    # Summary
    logger.info("\n" + "=" * 60)
    logger.info("TEST SUMMARY")
    logger.info("=" * 60)
    for test_name, passed in results:
        status = "βœ… PASS" if passed else "❌ FAIL"
        logger.info(f"{status}: {test_name}")
    total_passed = sum(1 for _, passed in results if passed)
    total_tests = len(results)
    logger.info(f"\nTotal: {total_passed}/{total_tests} tests passed")
    if total_passed == total_tests:
        logger.info("πŸŽ‰ All tests passed!")
        return 0
    else:
        logger.warning(f"⚠️ {total_tests - total_passed} test(s) failed")
        return 1
if __name__ == "__main__":
sys.exit(main())