|
|
""" |
|
|
Performance tests for speech pathology diagnosis system. |
|
|
|
|
|
Tests latency requirements: |
|
|
- File batch: <200ms per file |
|
|
- Per-frame: <50ms |
|
|
- WebSocket roundtrip: <100ms |
|
|
""" |
|
|
|
|
|
import asyncio
import logging
import time
from pathlib import Path
from typing import Any, Dict, List, Optional

import numpy as np
|
|
|
|
|
logger = logging.getLogger(__name__) |
|
|
|
|
|
|
|
|
def generate_test_audio(duration_seconds: float = 1.0, sample_rate: int = 16000) -> np.ndarray:
    """Create a synthetic 440 Hz sine tone for latency testing.

    Args:
        duration_seconds: Length of the clip in seconds.
        sample_rate: Sampling frequency in Hz.

    Returns:
        Mono float32 audio array at half amplitude.
    """
    sample_count = int(duration_seconds * sample_rate)
    timeline = np.linspace(0, duration_seconds, sample_count)
    # A4 pitch (440 Hz) scaled to 0.5 to leave headroom.
    tone = np.sin(2 * np.pi * 440 * timeline) * 0.5
    return tone.astype(np.float32)
|
|
|
|
|
|
|
|
def test_batch_latency(pipeline, num_files: int = 10) -> Dict[str, float]:
    """Measure end-to-end latency of batch (file-based) inference.

    Writes synthetic 1-second WAV files to temp storage, runs
    ``pipeline.predict_phone_level`` on each, and reports latency
    statistics against the 200 ms-per-file target.

    Args:
        pipeline: InferencePipeline instance.
        num_files: Number of test files to process.

    Returns:
        Dictionary with latency statistics, or ``{"error": ...}`` if
        every run failed.
    """
    # Lazy imports: soundfile is only needed on this test path.
    # Hoisted out of the loop — re-importing per file was pure overhead.
    import os
    import tempfile
    import soundfile as sf

    logger.info(f"Testing batch latency with {num_files} files...")

    latencies = []

    for i in range(num_files):
        audio = generate_test_audio(duration_seconds=1.0)

        # Close the temp file before writing to it: on Windows a file
        # cannot be reopened while the NamedTemporaryFile handle is open.
        with tempfile.NamedTemporaryFile(suffix='.wav', delete=False) as f:
            temp_path = f.name
        sf.write(temp_path, audio, 16000)

        try:
            start_time = time.time()
            result = pipeline.predict_phone_level(temp_path, return_timestamps=True)
            latency_ms = (time.time() - start_time) * 1000
            latencies.append(latency_ms)

            logger.info(f"  File {i+1}: {latency_ms:.1f}ms ({result.num_frames} frames)")
        except Exception as e:
            # Best-effort: a single failed file should not abort the benchmark.
            logger.error(f"  File {i+1} failed: {e}")
        finally:
            if os.path.exists(temp_path):
                os.remove(temp_path)

    if not latencies:
        return {"error": "No successful runs"}

    avg_latency = sum(latencies) / len(latencies)
    max_latency = max(latencies)
    min_latency = min(latencies)

    result = {
        "avg_latency_ms": avg_latency,
        "max_latency_ms": max_latency,
        "min_latency_ms": min_latency,
        "num_files": len(latencies),
        "target_ms": 200.0,
        "passed": avg_latency < 200.0
    }

    # Fixed: the original f-string was split by a raw newline after a
    # garbled emoji, which is a syntax error.
    logger.info(f"Batch latency test: avg={avg_latency:.1f}ms, max={max_latency:.1f}ms, "
                f"target=200ms, passed={result['passed']}")

    return result
|
|
|
|
|
|
|
|
def test_frame_latency(pipeline, num_frames: int = 100) -> Dict[str, float]:
    """Measure per-call inference latency on an in-memory audio array.

    Runs ``pipeline.predict_phone_level`` repeatedly on the same
    1-second clip and reports latency statistics against the 50 ms
    per-frame target.

    Args:
        pipeline: InferencePipeline instance.
        num_frames: Number of inference calls to time.

    Returns:
        Dictionary with latency statistics, or ``{"error": ...}`` if
        every run failed.
    """
    logger.info(f"Testing frame latency with {num_frames} frames...")

    # One shared clip: we time the model call, not audio generation.
    audio = generate_test_audio(duration_seconds=1.0)

    latencies = []

    for i in range(num_frames):
        start_time = time.time()
        try:
            # Return value is unused here; only timing matters.
            pipeline.predict_phone_level(audio, return_timestamps=False)
            latency_ms = (time.time() - start_time) * 1000
            latencies.append(latency_ms)
        except Exception as e:
            logger.error(f"  Frame {i+1} failed: {e}")

    if not latencies:
        return {"error": "No successful runs"}

    avg_latency = sum(latencies) / len(latencies)
    max_latency = max(latencies)
    min_latency = min(latencies)
    # Positional 95th percentile; int() truncation keeps the index in range.
    p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]

    result = {
        "avg_latency_ms": avg_latency,
        "max_latency_ms": max_latency,
        "min_latency_ms": min_latency,
        "p95_latency_ms": p95_latency,
        "num_frames": len(latencies),
        "target_ms": 50.0,
        "passed": avg_latency < 50.0
    }

    # Fixed: the original f-string was split by a raw newline after a
    # garbled emoji, which is a syntax error.
    logger.info(f"Frame latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
                f"target=50ms, passed={result['passed']}")

    return result
|
|
|
|
|
|
|
|
async def test_websocket_latency(websocket_url: str, num_chunks: int = 50) -> Dict[str, float]:
    """Measure WebSocket send/receive roundtrip latency.

    Streams 20 ms PCM16 chunks to the server and times each
    send+recv pair against the 100 ms roundtrip target.

    Args:
        websocket_url: WebSocket URL to connect to.
        num_chunks: Number of chunks to send.

    Returns:
        Dictionary with latency statistics, or ``{"error": ...}`` when
        the ``websockets`` library is missing or the test fails.
    """
    try:
        import websockets  # optional dependency; absence skips the test

        logger.info(f"Testing WebSocket latency with {num_chunks} chunks...")

        latencies = []

        async with websockets.connect(websocket_url) as websocket:
            # 20 ms of audio at 16 kHz = 320 samples, sent as 16-bit PCM.
            audio_chunk = generate_test_audio(duration_seconds=0.02)
            chunk_bytes = (audio_chunk * 32768).astype(np.int16).tobytes()

            for i in range(num_chunks):
                start_time = time.time()

                await websocket.send(chunk_bytes)
                # Block for the server's reply so the measurement covers
                # the full roundtrip, not just the send.
                await websocket.recv()

                latency_ms = (time.time() - start_time) * 1000
                latencies.append(latency_ms)

                if i % 10 == 0:
                    logger.info(f"  Chunk {i+1}: {latency_ms:.1f}ms")

        if not latencies:
            return {"error": "No successful runs"}

        avg_latency = sum(latencies) / len(latencies)
        max_latency = max(latencies)
        # Positional 95th percentile; int() truncation keeps the index in range.
        p95_latency = sorted(latencies)[int(len(latencies) * 0.95)]

        result = {
            "avg_latency_ms": avg_latency,
            "max_latency_ms": max_latency,
            "p95_latency_ms": p95_latency,
            "num_chunks": len(latencies),
            "target_ms": 100.0,
            "passed": avg_latency < 100.0
        }

        # Fixed: the original f-string was split by a raw newline after a
        # garbled emoji, which is a syntax error.
        logger.info(f"WebSocket latency test: avg={avg_latency:.1f}ms, p95={p95_latency:.1f}ms, "
                    f"target=100ms, passed={result['passed']}")

        return result

    except ImportError:
        logger.warning("websockets library not available, skipping WebSocket test")
        return {"error": "websockets library not available"}
    except Exception as e:
        logger.error(f"WebSocket test failed: {e}")
        return {"error": str(e)}
|
|
|
|
|
|
|
|
def test_concurrent_connections(pipeline, num_connections: int = 10) -> Dict[str, Any]:
    """Simulate concurrent inference requests via a thread pool.

    Args:
        pipeline: InferencePipeline instance.
        num_connections: Number of simultaneous requests.

    Returns:
        Dictionary with success counts, average latency, and throughput.
    """
    logger.info(f"Testing {num_connections} concurrent connections...")

    import concurrent.futures

    def process_audio(i: int):
        """Run and time a single inference request; never raises."""
        try:
            audio = generate_test_audio(duration_seconds=0.5)
            start_time = time.time()
            result = pipeline.predict_phone_level(audio, return_timestamps=False)
            latency_ms = (time.time() - start_time) * 1000
            return {"success": True, "latency_ms": latency_ms, "frames": result.num_frames}
        except Exception as e:
            return {"success": False, "error": str(e)}

    start_time = time.time()

    with concurrent.futures.ThreadPoolExecutor(max_workers=num_connections) as executor:
        futures = [executor.submit(process_audio, i) for i in range(num_connections)]
        results = [f.result() for f in concurrent.futures.as_completed(futures)]

    total_time = time.time() - start_time

    successful = sum(1 for r in results if r.get("success", False))
    avg_latency = (
        sum(r["latency_ms"] for r in results if r.get("success", False)) / successful
        if successful > 0 else 0.0
    )

    result = {
        "total_connections": num_connections,
        "successful": successful,
        "failed": num_connections - successful,
        "total_time_seconds": total_time,
        "avg_latency_ms": avg_latency,
        "throughput_per_second": successful / total_time if total_time > 0 else 0.0
    }

    # Fixed: the original f-string was split by a raw newline after a
    # garbled emoji, which is a syntax error.
    logger.info(f"Concurrent test: {successful}/{num_connections} successful, "
                f"avg_latency={avg_latency:.1f}ms, throughput={result['throughput_per_second']:.1f}/s")

    return result
|
|
|
|
|
|
|
|
def run_all_performance_tests(pipeline, websocket_url: Optional[str] = None) -> Dict[str, Any]:
    """Run every performance test and log a pass/fail summary.

    Args:
        pipeline: InferencePipeline instance.
        websocket_url: Optional WebSocket URL; the streaming test is
            skipped when omitted.

    Returns:
        Dictionary mapping test name to that test's result dictionary.
    """
    logger.info("=" * 60)
    logger.info("Running Performance Tests")
    logger.info("=" * 60)

    results = {}

    logger.info("\n1. Batch File Latency Test")
    results["batch_latency"] = test_batch_latency(pipeline)

    logger.info("\n2. Per-Frame Latency Test")
    results["frame_latency"] = test_frame_latency(pipeline)

    logger.info("\n3. Concurrent Connections Test")
    results["concurrent"] = test_concurrent_connections(pipeline, num_connections=10)

    if websocket_url:
        logger.info("\n4. WebSocket Latency Test")
        # The coroutine needs its own event loop; this module is otherwise sync.
        results["websocket_latency"] = asyncio.run(test_websocket_latency(websocket_url))

    logger.info("\n" + "=" * 60)
    logger.info("Performance Test Summary")
    logger.info("=" * 60)

    # Fixed: the PASSED messages below were split by raw newlines after
    # garbled emoji characters (syntax errors); the FAILED messages
    # carried the same garbled character.
    if "batch_latency" in results and results["batch_latency"].get("passed"):
        logger.info("Batch latency: PASSED")
    else:
        logger.warning("Batch latency: FAILED")

    if "frame_latency" in results and results["frame_latency"].get("passed"):
        logger.info("Frame latency: PASSED")
    else:
        logger.warning("Frame latency: FAILED")

    if "websocket_latency" in results and results["websocket_latency"].get("passed"):
        logger.info("WebSocket latency: PASSED")
    elif "websocket_latency" in results:
        logger.warning("WebSocket latency: FAILED")

    return results
|
|
|
|
|
|
|
|
if __name__ == "__main__":
    # Manual entry point: build a pipeline and run the full suite.
    logging.basicConfig(level=logging.INFO)

    try:
        import json

        from inference.inference_pipeline import create_inference_pipeline

        inference_pipeline = create_inference_pipeline()
        all_results = run_all_performance_tests(inference_pipeline)

        print("\nTest Results:")
        print(json.dumps(all_results, indent=2))
    except Exception as e:
        logger.error(f"Test failed: {e}", exc_info=True)
|
|
|
|
|
|