"""
Direct WebSocket endpoint validation for STT and TTS services.

Tests each service independently to verify WebSocket functionality.
The endpoints default to the hosted services below; set the
STT_WEBSOCKET_URL / TTS_WEBSOCKET_URL environment variables to point the
checks at a different deployment.
"""

import asyncio
import base64
import json
import logging
import os
import sys
from datetime import datetime

import websockets

logging.basicConfig(level=logging.INFO, format='%(asctime)s - %(levelname)s - %(message)s')
logger = logging.getLogger(__name__)

# Endpoints under test. An unset or empty environment variable falls back to
# the original hard-coded URL, so default behaviour is unchanged.
STT_WEBSOCKET_URL = os.environ.get("STT_WEBSOCKET_URL") or "wss://pgits-stt-gpu-service.hf.space/ws/stt"
TTS_WEBSOCKET_URL = os.environ.get("TTS_WEBSOCKET_URL") or "wss://pgits-tts-gpu-service.hf.space/ws/tts"
|
|
class WebSocketTester:
    """Direct WebSocket endpoint tester for the STT and TTS services.

    Outcomes are collected in ``self.test_results``:

    * ``"stt_reachable"`` / ``"tts_reachable"`` -- bool from the probe step.
    * ``"stt"`` / ``"tts"`` -- dict with a ``"success"`` flag plus either the
      result details (``"transcription"`` / ``"audio_size"``) or ``"error"``.
    """

    # Seconds allowed for the WebSocket opening handshake in functional tests.
    CONNECT_TIMEOUT = 10
    # Shorter handshake limit used by the reachability probe.
    PROBE_TIMEOUT = 5
    # Seconds to wait for the service's first (confirmation) message.
    CONFIRMATION_TIMEOUT = 15
    # Seconds to wait for the reply to the test request.
    STT_RESPONSE_TIMEOUT = 30
    TTS_RESPONSE_TIMEOUT = 60

    def __init__(self):
        self.test_results = {}

    def create_test_audio_data(self):
        """Return placeholder bytes used as the STT test payload.

        The payload is an ASCII marker followed by 1000 ``b"0"`` bytes. It is
        not a decodable audio file, so it only exercises the message path.
        """
        return b'webm_test_audio_data_' + b'0' * 1000

    @staticmethod
    def _describe(exc):
        """Return ``str(exc)``, or the exception class name if that is empty.

        Timeout exceptions typically stringify to ``""``, which would
        otherwise yield messages such as "not reachable: ".
        """
        return str(exc) or type(exc).__name__

    def _fail(self, key, error):
        """Record a failed functional test for ``key`` and return False."""
        self.test_results[key] = {"success": False, "error": error}
        return False

    def _service_ok(self, key):
        """Return True when ``key`` was reachable and its functional test passed."""
        reachable = bool(self.test_results.get(f"{key}_reachable"))
        return reachable and self.test_results.get(key, {}).get("success", False)

    async def _exchange(self, key, url, confirmation_type, request, response_timeout, send_log):
        """Connect, verify the confirmation, send one request and return the reply.

        Args:
            key: Service key ("stt" or "tts"); used as the ``test_results`` key
                and, upper-cased, as the log label.
            url: WebSocket endpoint to connect to.
            confirmation_type: ``type`` the first server message must carry.
            request: JSON-serialisable message sent after the confirmation.
            response_timeout: Seconds to wait for the reply to ``request``.
            send_log: Message logged just before ``request`` is sent.

        Returns:
            The decoded reply dict, or None if any step failed. On failure the
            error has already been logged and stored via ``_fail``.
        """
        label = key.upper()
        # Which await is in flight, so a timeout is attributed to the right step
        # (previously a slow reply was misreported as a confirmation timeout).
        stage = "Connection"
        try:
            logger.info(f"Connecting to: {url}")
            # open_timeout bounds the opening handshake; the old `timeout=` kwarg
            # is not a connect timeout (legacy alias of close_timeout, TypeError
            # on the websockets >= 14 client).
            async with websockets.connect(url, open_timeout=self.CONNECT_TIMEOUT) as websocket:
                logger.info(f"✅ {label} WebSocket connection established")

                stage = "Confirmation"
                confirmation = await asyncio.wait_for(websocket.recv(), timeout=self.CONFIRMATION_TIMEOUT)
                confirmation_data = json.loads(confirmation)
                logger.info(f"📨 {label} confirmation received: {confirmation_data}")

                if not isinstance(confirmation_data, dict) or confirmation_data.get("type") != confirmation_type:
                    logger.error(f"❌ Invalid {label} confirmation: {confirmation_data}")
                    self._fail(key, f"Invalid confirmation: {confirmation_data}")
                    return None
                logger.info(f"✅ {label} connection confirmed properly")

                logger.info(send_log)
                stage = "Response"
                await websocket.send(json.dumps(request))
                response = await asyncio.wait_for(websocket.recv(), timeout=response_timeout)
                response_data = json.loads(response)

        except asyncio.TimeoutError:
            logger.error(f"❌ {label} {stage.lower()} timeout")
            self._fail(key, f"{stage} timeout")
            return None
        except websockets.exceptions.InvalidHandshake as e:
            # Common base of the legacy InvalidStatusCode and the newer
            # InvalidStatus, so a rejected handshake is reported the same way
            # on every websockets version.
            logger.error(f"❌ {label} WebSocket handshake failed: {e}")
            self._fail(key, f"Handshake failed: {e}")
            return None
        except websockets.exceptions.WebSocketException as e:
            logger.error(f"❌ {label} WebSocket error: {e}")
            self._fail(key, f"WebSocket error: {e}")
            return None
        except Exception as e:
            logger.error(f"❌ {label} test failed: {self._describe(e)}")
            self._fail(key, self._describe(e))
            return None

        if not isinstance(response_data, dict):
            logger.warning(f"⚠️ Unexpected {label} response type: {response_data}")
            self._fail(key, f"Unexpected response: {response_data}")
            return None
        return response_data

    def _reply_failure(self, key, response_data, error_type):
        """Log and record a reply that is not the expected success message.

        A reply whose ``type`` equals ``error_type`` is reported as a service
        error using its ``message`` field; anything else is reported as an
        unexpected response. Always returns False.
        """
        label = key.upper()
        if response_data.get("type") == error_type:
            error_msg = response_data.get("message", "Unknown error")
            logger.error(f"❌ {label} service error: {error_msg}")
            return self._fail(key, error_msg)
        logger.warning(f"⚠️ Unexpected {label} response type: {response_data}")
        return self._fail(key, f"Unexpected response: {response_data}")

    async def test_stt_websocket(self):
        """Send placeholder audio to the STT endpoint and check the reply.

        Returns:
            True when the service answers with ``stt_transcription``; False
            otherwise (details stored in ``self.test_results["stt"]``).
        """
        logger.info("🎤 Testing STT WebSocket endpoint...")

        audio_b64 = base64.b64encode(self.create_test_audio_data()).decode('utf-8')
        message = {
            "type": "stt_audio_chunk",
            "audio_data": audio_b64,
            "language": "auto",
            "model_size": "base",
        }

        response_data = await self._exchange(
            "stt",
            STT_WEBSOCKET_URL,
            "stt_connection_confirmed",
            message,
            self.STT_RESPONSE_TIMEOUT,
            "📤 Sending test audio to STT...",
        )
        if response_data is None:
            return False

        logger.info(f"📨 STT response: {response_data}")

        if response_data.get("type") == "stt_transcription":
            transcription = response_data.get("text", "")
            logger.info(f"✅ STT transcription received: {transcription}")
            self.test_results["stt"] = {"success": True, "transcription": transcription}
            return True
        return self._reply_failure("stt", response_data, "stt_error")

    async def test_tts_websocket(self):
        """Ask the TTS endpoint to synthesize a short sentence and check the reply.

        Returns:
            True when the service answers with ``tts_audio_response``; False
            otherwise (details stored in ``self.test_results["tts"]``).
        """
        logger.info("🔊 Testing TTS WebSocket endpoint...")

        test_text = "Hello, this is a WebSocket test of the text to speech service."
        message = {
            "type": "tts_synthesize",
            "text": test_text,
            "voice_preset": "v2/en_speaker_6",
        }

        response_data = await self._exchange(
            "tts",
            TTS_WEBSOCKET_URL,
            "tts_connection_confirmed",
            message,
            self.TTS_RESPONSE_TIMEOUT,
            f"📤 Sending test text to TTS: {test_text}",
        )
        if response_data is None:
            return False

        # Only the type is logged on this path; presumably the full reply can be
        # large (audio payload) -- TODO confirm against the TTS service.
        logger.info(f"📨 TTS response type: {response_data.get('type')}")

        if response_data.get("type") == "tts_audio_response":
            audio_size = response_data.get("audio_size", 0)
            logger.info(f"✅ TTS audio generated: {audio_size} bytes")
            self.test_results["tts"] = {"success": True, "audio_size": audio_size}
            return True
        return self._reply_failure("tts", response_data, "tts_error")

    async def _probe(self, label, url):
        """Return True if a WebSocket handshake with ``url`` completes."""
        try:
            logger.info(f"Testing connection to: {url}")
            async with websockets.connect(url, open_timeout=self.PROBE_TIMEOUT):
                logger.info(f"✅ {label} endpoint is reachable")
        except Exception as e:
            logger.error(f"❌ {label} endpoint not reachable: {self._describe(e)}")
            return False
        return True

    async def test_endpoint_availability(self):
        """Probe both endpoints and store ``stt_reachable`` / ``tts_reachable``."""
        logger.info("🔍 Testing endpoint availability...")
        self.test_results["stt_reachable"] = await self._probe("STT", STT_WEBSOCKET_URL)
        self.test_results["tts_reachable"] = await self._probe("TTS", TTS_WEBSOCKET_URL)

    async def run_all_tests(self):
        """Probe both endpoints, run the functional tests on reachable ones, report.

        Returns:
            True only when both the STT and TTS functional tests passed.
        """
        logger.info("🚀 Starting WebSocket endpoint validation...")
        logger.info(f"Test started at: {datetime.now().isoformat()}")

        await self.test_endpoint_availability()

        stt_success = False
        if self.test_results.get("stt_reachable"):
            stt_success = await self.test_stt_websocket()
        else:
            logger.warning("⚠️ Skipping STT functional test - endpoint not reachable")

        # Short pause between the two functional tests; presumably to avoid
        # back-to-back load on the services -- TODO confirm.
        await asyncio.sleep(2)

        tts_success = False
        if self.test_results.get("tts_reachable"):
            tts_success = await self.test_tts_websocket()
        else:
            logger.warning("⚠️ Skipping TTS functional test - endpoint not reachable")

        self.print_test_results()

        return stt_success and tts_success

    def _log_service_summary(self, key, heading, success_detail):
        """Log reachability and the functional result for one service.

        ``success_detail`` maps a successful result dict to the extra line
        logged after "PASS".
        """
        logger.info(heading)
        reachable = "✅" if self.test_results.get(f"{key}_reachable") else "❌"
        logger.info(f" Endpoint Reachable: {reachable}")

        result = self.test_results.get(key)
        if result is None:
            logger.info(" WebSocket Function: ⚠️ NOT TESTED")
        elif result["success"]:
            logger.info(" WebSocket Function: ✅ PASS")
            logger.info(success_detail(result))
        else:
            logger.info(" WebSocket Function: ❌ FAIL")
            logger.info(f" Error: {result.get('error', 'Unknown')}")

    def print_test_results(self):
        """Log a per-service summary followed by an overall verdict."""
        logger.info("\n" + "=" * 70)
        logger.info("📊 WEBSOCKET ENDPOINT VALIDATION RESULTS")
        logger.info("=" * 70)

        self._log_service_summary(
            "stt",
            "🎤 STT Service:",
            lambda result: f" Transcription: {result.get('transcription', 'N/A')}",
        )
        self._log_service_summary(
            "tts",
            "\n🔊 TTS Service:",
            lambda result: f" Audio Generated: {result.get('audio_size', 0)} bytes",
        )

        logger.info("=" * 70)

        stt_ok = self._service_ok("stt")
        tts_ok = self._service_ok("tts")

        if stt_ok and tts_ok:
            logger.info("🎉 ALL WEBSOCKET ENDPOINTS WORKING!")
            logger.info("✅ Ready for ChatCal WebRTC integration")
        elif stt_ok or tts_ok:
            logger.warning("⚠️ PARTIAL SUCCESS - Some endpoints working")
            if not stt_ok:
                logger.warning("❌ STT WebSocket needs attention")
            if not tts_ok:
                logger.warning("❌ TTS WebSocket needs attention")
        else:
            logger.error("❌ NO WEBSOCKET ENDPOINTS WORKING")
            logger.error("🔧 Services need WebSocket endpoint deployment")

        logger.info(f"🏁 Test completed at: {datetime.now().isoformat()}")
|
|
async def main():
    """Run the full validation suite and map the outcome to an exit code.

    Returns:
        0 when both the STT and TTS functional tests passed, otherwise 1.
    """
    tester = WebSocketTester()

    try:
        success = await tester.run_all_tests()
    except KeyboardInterrupt:
        logger.info("❌ Tests interrupted by user")
        return 1
    except Exception as e:
        # Top-level boundary: keep the traceback so unexpected failures are debuggable.
        logger.exception(f"❌ Test runner failed: {e}")
        return 1
    return 0 if success else 1
|
|
if __name__ == "__main__":
    try:
        exit_code = asyncio.run(main())
    except KeyboardInterrupt:
        # On Python 3.11+ asyncio.run() cancels main() on Ctrl+C and re-raises
        # KeyboardInterrupt here, so main()'s own handler does not see it.
        logger.info("❌ Tests interrupted by user")
        exit_code = 1
    except Exception as e:
        logger.error(f"❌ Failed to run tests: {e}")
        exit_code = 1
    sys.exit(exit_code)