nice-bill commited on
Commit
bce1b3c
·
1 Parent(s): fe88a2e

added test for websocket

Browse files
Files changed (1) hide show
  1. src/api/test_ws.py +71 -0
src/api/test_ws.py ADDED
@@ -0,0 +1,71 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import websocket
2
+ import json
3
+ import time
4
+ import librosa
5
+ import numpy as np
6
+ import os
7
+ import torch
8
+
9
+ # --- CONFIG ---
10
+ MOCK_MIC_RATE = 44100 # Simulate a 44.1kHz microphone
11
+ WS_URL = f"ws://localhost:8000/stream/audio?rate={MOCK_MIC_RATE}"
12
+ TEST_FILE = r"C:\dev\archive\Emotions\Angry\03-01-05-01-01-01-01.wav"
13
+
14
+ def test_streaming():
15
+ if not os.path.exists(TEST_FILE):
16
+ print(f"Test file not found: {TEST_FILE}")
17
+ return
18
+
19
+ print(f"Loading test file at {MOCK_MIC_RATE}Hz to simulate high-res mic...")
20
+ # Load audio and resample to the MOCK rate
21
+ speech, _ = librosa.load(TEST_FILE, sr=MOCK_MIC_RATE)
22
+
23
+ # Connect to WebSocket
24
+ print(f"Connecting to {WS_URL}...")
25
+ try:
26
+ ws = websocket.create_connection(WS_URL)
27
+ except Exception as e:
28
+ print(f"Connection failed: {e}")
29
+ return
30
+
31
+ try:
32
+ # Send 0.5s chunks of 44.1kHz data
33
+ chunk_size = int(MOCK_MIC_RATE * 0.5)
34
+
35
+ print("Starting Stream...")
36
+ for i in range(0, len(speech), chunk_size):
37
+ chunk = speech[i:i + chunk_size]
38
+ if len(chunk) == 0: continue
39
+
40
+ # Convert to 16-bit PCM
41
+ chunk_int16 = (chunk * 32767).astype(np.int16)
42
+
43
+ # Send binary data
44
+ ws.send_binary(chunk_int16.tobytes())
45
+
46
+ # Receive response
47
+ try:
48
+ # Set a longer timeout for resampling latency
49
+ ws.settimeout(0.5)
50
+ result = ws.recv()
51
+ data = json.loads(result)
52
+
53
+ # Check for the new status field and confidence
54
+ status_marker = "[ALERT]" if data.get('status') == "high_confidence" else "[INFO]"
55
+ print(f"{status_marker} Prediction: {data['emotion']} | Conf: {data['confidence']:.2%} | Status: {data.get('status')}")
56
+
57
+ except websocket.WebSocketTimeoutException:
58
+ pass
59
+
60
+ time.sleep(0.5)
61
+
62
+ print("\nStream Finished.")
63
+
64
+ except Exception as e:
65
+ print(f"Error during stream: {e}")
66
+ finally:
67
+ ws.close()
68
+ print("WebSocket Closed.")
69
+
70
+ if __name__ == "__main__":
71
+ test_streaming()