jafet21 commited on
Commit
1ac339e
·
verified ·
1 Parent(s): 9eadafc

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +180 -0
README.md CHANGED
@@ -11,3 +11,183 @@ Output: Score matrix per class [num_frames, num_classes]
11
 
12
  Classes: See mapping in yamnet_class_map.csv
13
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
11
 
12
  Classes: See mapping in yamnet_class_map.csv
13
 
14
+ Example usage:
15
+
16
+ import sounddevice as sd
17
+ import numpy as np
18
+ import onnxruntime
19
+ import scipy.signal
20
+ import csv
21
+ import threading
22
+ import time
23
+ from collections import deque, Counter
24
+
25
# Location of the exported YAMNet model on disk.
MODEL_PATH = "./yamnet.onnx"

# CPU-only ONNX Runtime session for the model.
session = onnxruntime.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])

# Name of the model's (single) input tensor, needed when feeding audio.
input_name = session.get_inputs()[0].name
32
def load_class_map(csv_path="yamnet_class_map.csv"):
    """Read the YAMNet class-map CSV and return a {class_id: class_name} dict.

    The CSV is expected to have a header row and at least three columns,
    with the integer class ID in column 0 and the class name in column 2.
    """
    with open(csv_path, newline='') as handle:
        rows = csv.reader(handle)
        next(rows)  # drop the header row
        return {int(row[0]): row[2] for row in rows}
43
+
44
# Class-ID -> name mapping loaded from the bundled CSV.
class_map = load_class_map()

# Rolling capture buffer: 1.5 seconds of audio at the model's 16 kHz rate.
BUFFER_SIZE = int(16000 * 1.5)
audio_buffer = deque(maxlen=BUFFER_SIZE)

# Recent raw detections and their consolidated (majority-voted) results.
detections_queue = deque(maxlen=3)
consolidated_results = deque(maxlen=3)
last_printed_result = None   # last console message, used to de-duplicate output
last_inference_time = 0      # timestamp of the most recent inference

# Sleep mode: after repeatedly hearing these classes, pause inference briefly.
sleep_triggers = {"music", "speech", "silence"}
sleep_duration = 3.0         # seconds to stay paused once triggered
same_class_count = 0         # consecutive detections of the same class
last_detected_class = None   # most recently detected class name
is_sleeping_until = 0        # wall-clock time until which inference is paused
64
+
65
def resample_if_needed(audio, original_sr, target_sr=16000):
    """Return *audio* resampled to *target_sr*; unchanged if already there."""
    if original_sr == target_sr:
        return audio
    # Polyphase resampling by the ratio target_sr / original_sr.
    return scipy.signal.resample_poly(audio, target_sr, int(original_sr))
70
+
71
def run_inference(audio_chunk):
    """Run YAMNet ONNX inference on an audio chunk and consolidate detections.

    Args:
        audio_chunk: 1-D float waveform at 16 kHz (cast to float32 before
            being fed to the model).

    Side effects: appends to detections_queue / consolidated_results, may
    enter sleep mode (is_sleeping_until), and prints consolidated results.
    """
    global detections_queue, consolidated_results, last_printed_result
    global is_sleeping_until, same_class_count, last_detected_class

    try:
        # Run the ONNX model; scores is [num_frames, num_classes].
        outputs = session.run(None, {input_name: audio_chunk.astype(np.float32)})
        scores = outputs[0]
        mean_scores = np.mean(scores, axis=0)  # Average scores across frames

        # Get the class with highest mean score.
        class_id = int(np.argmax(mean_scores))
        confidence = mean_scores[class_id]
        class_name = class_map.get(class_id, f"Class {class_id}")

        # Track consecutive detections of the same class.
        if class_name == last_detected_class:
            same_class_count += 1
        else:
            same_class_count = 1
            last_detected_class = class_name

        # Enter sleep mode if we detect trigger classes consecutively.
        # FIX: compare case-insensitively — YAMNet display names are
        # capitalized ("Speech", "Music") while sleep_triggers is lowercase,
        # so the original exact-match check could never fire.
        if same_class_count >= 3 and class_name.lower() in sleep_triggers:
            is_sleeping_until = time.time() + sleep_duration
            print(f" Pausing inferences for {sleep_duration} seconds (detected: {class_name})")
            same_class_count = 0

        # Add detection to queue.
        detections_queue.append((class_name, confidence))

        # FIX: consolidate only when the queue is actually full (maxlen == 3).
        # The original tested `len(detections_queue) == 1`, which consolidated
        # after every single detection and made the majority vote meaningless.
        if len(detections_queue) == detections_queue.maxlen:
            most_common_name, _ = Counter([d[0] for d in detections_queue]).most_common(1)[0]
            relevant_conf = [conf for n, conf in detections_queue if n == most_common_name]
            avg_conf = np.mean(relevant_conf) * 100
            if avg_conf > 20:  # Only keep results with >20% confidence.
                consolidated_results.append((most_common_name, avg_conf))
            detections_queue.clear()

        # When we have enough consolidated results, print the final result.
        if len(consolidated_results) == 3:
            names = [r[0] for r in consolidated_results]
            most_common_name, _ = Counter(names).most_common(1)[0]
            avg_conf = np.mean([c for n, c in consolidated_results if n == most_common_name])
            msg = f" Consolidated result: {most_common_name} with average confidence {avg_conf:.2f}%"
            # Avoid spamming the console with identical consecutive results.
            if msg != last_printed_result:
                print(msg)
                last_printed_result = msg

    except Exception as e:
        print(f" ONNX Error: {e}")
125
+
126
def audio_callback(indata, frames, time_info, status):
    """Stream callback: buffer incoming audio and schedule inference.

    Collects mono samples (resampled to 16 kHz and peak-normalized) into the
    circular buffer, then launches a background inference at most once per
    second while not in sleep mode.
    """
    global last_inference_time, is_sleeping_until

    if status:
        print(f" Status: {status}")

    # Mono channel, resampled to the model's 16 kHz rate if needed.
    device_sr = sd.query_devices(sd.default.device[0], 'input')['default_samplerate']
    mono = resample_if_needed(indata[:, 0], device_sr)

    # Peak-normalize this chunk (skip silent chunks to avoid divide-by-zero).
    peak = np.max(np.abs(mono))
    if peak > 0:
        mono = mono / peak

    audio_buffer.extend(mono)

    # Not enough audio accumulated yet — wait for more.
    if len(audio_buffer) < BUFFER_SIZE:
        return

    now = time.time()
    # Sleep mode: skip inference entirely until the pause expires.
    if now < is_sleeping_until:
        return
    # Throttle: at most one inference per second.
    if now - last_inference_time <= 1.0:
        return

    last_inference_time = now
    chunk = np.array(audio_buffer)
    # Run inference off the audio thread so the callback stays fast.
    threading.Thread(target=run_inference, args=(chunk,), daemon=True).start()
159
+
160
def main():
    """Entry point: pick an input device and stream audio into the classifier."""
    print(" Starting real-time listening (Ctrl+C to stop)...")

    # Enumerate devices that can capture audio.
    devices = sd.query_devices()
    print("\nAvailable input devices:")
    for idx, dev in enumerate(devices):
        if dev['max_input_channels'] > 0:
            print(f"{idx}: {dev['name']}")

    # Ask the user which device to record from.
    device_id = int(input("Select input device ID: "))
    sd.default.device = (device_id, None)

    # Capture mono at the device's native sample rate.
    samplerate = int(sd.query_devices(device_id, 'input')['default_samplerate'])
    sd.default.samplerate = samplerate
    sd.default.channels = 1

    print("🎧 Listening...")

    # Stream until interrupted; audio_callback does all the work.
    with sd.InputStream(callback=audio_callback, channels=1, samplerate=samplerate):
        try:
            while True:
                time.sleep(0.1)  # keep the main thread alive
        except KeyboardInterrupt:
            print("\n Stopping listener...")

if __name__ == "__main__":
    main()
193
+