jafet21 committed on
Commit
337f563
·
verified ·
1 Parent(s): 5a1c71d

Update README.md

Browse files
Files changed (1) hide show
  1. README.md +182 -0
README.md CHANGED
@@ -12,3 +12,185 @@ Output: Score matrix per class [num_frames, num_classes]
12
  Classes: See mapping in yamnet_class_map.csv
13
 
14
  Example usage
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
12
  Classes: See mapping in yamnet_class_map.csv
13
 
14
  Example usage
15
+ ```python
16
+
17
+ import sounddevice as sd
18
+ import numpy as np
19
+ import onnxruntime
20
+ import scipy.signal
21
+ import csv
22
+ import threading
23
+ import time
24
+ from collections import deque, Counter
25
+
26
+ # Path to the ONNX model file
27
+ MODEL_PATH = "./yamnet.onnx"
28
+ # Create an inference session with ONNX runtime using CPU provider
29
+ session = onnxruntime.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
30
+ # Get the name of the model's input node
31
+ input_name = session.get_inputs()[0].name
32
+
33
+ def load_class_map(csv_path="yamnet_class_map.csv"):
34
+ """Load class names from a CSV file into a dictionary mapping class IDs to names."""
35
+ class_map = {}
36
+ with open(csv_path, newline='') as csvfile:
37
+ reader = csv.reader(csvfile)
38
+ next(reader) # Skip header row
39
+ for row in reader:
40
+ class_id = int(row[0])
41
+ class_name = row[2] # Class name is in the third column
42
+ class_map[class_id] = class_name
43
+ return class_map
44
+
45
+ # Load the class mapping from CSV
46
+ class_map = load_class_map()
47
+
48
+ # Audio buffer settings (1.5 seconds at 16kHz sample rate)
49
+ BUFFER_SIZE = int(16000 * 1.5)
50
+ # Circular buffer to store audio samples
51
+ audio_buffer = deque(maxlen=BUFFER_SIZE)
52
+
53
+ # Queues for storing and consolidating detection results
54
+ detections_queue = deque(maxlen=3) # Stores recent detections
55
+ consolidated_results = deque(maxlen=3) # Stores consolidated results
56
+ last_printed_result = None # Last result printed to console
57
+ last_inference_time = 0 # Timestamp of last inference
58
+
59
+ # Configuration for sleep mode (temporarily pausing inference)
60
+ sleep_triggers = {"music", "speech", "silence"} # Classes that trigger sleep mode
61
+ sleep_duration = 3.0 # How long to sleep after detecting trigger classes
62
+ same_class_count = 0 # Counter for consecutive same-class detections
63
+ last_detected_class = None # Last class detected
64
+ is_sleeping_until = 0 # Timestamp until which we're in sleep mode
65
+
66
+ def resample_if_needed(audio, original_sr, target_sr=16000):
67
+ """Resample audio to target sample rate if needed."""
68
+ if original_sr != target_sr:
69
+ audio = scipy.signal.resample_poly(audio, target_sr, int(original_sr))
70
+ return audio
71
+
72
+ def run_inference(audio_chunk):
73
+ """Run the ONNX model inference on an audio chunk and process results."""
74
+ global detections_queue, consolidated_results, last_printed_result
75
+ global is_sleeping_until, same_class_count, last_detected_class
76
+
77
+ try:
78
+ # Run the ONNX model
79
+ outputs = session.run(None, {input_name: audio_chunk.astype(np.float32)})
80
+ scores = outputs[0]
81
+ mean_scores = np.mean(scores, axis=0) # Average scores across frames
82
+
83
+ # Get the class with highest score
84
+ class_id = np.argmax(mean_scores)
85
+ confidence = mean_scores[class_id]
86
+ class_name = class_map.get(class_id, f"Class {class_id}")
87
+
88
+ # Track consecutive detections of the same class
89
+ if class_name == last_detected_class:
90
+ same_class_count += 1
91
+ else:
92
+ same_class_count = 1
93
+ last_detected_class = class_name
94
+
95
+ # Enter sleep mode if we detect trigger classes consecutively
96
+ if same_class_count >= 3 and class_name in sleep_triggers:
97
+ is_sleeping_until = time.time() + sleep_duration
98
+ print(f" Pausing inferences for {sleep_duration} seconds (detected: {class_name})")
99
+ same_class_count = 0
100
+
101
+ # Add detection to queue
102
+ detections_queue.append((class_name, confidence))
103
+
104
+ # Consolidate as soon as the queue holds one detection (it is cleared right after)
105
+ if len(detections_queue) == 1:
106
+ most_common_name, _ = Counter([d[0] for d in detections_queue]).most_common(1)[0]
107
+ relevant_conf = [conf for n, conf in detections_queue if n == most_common_name]
108
+ avg_conf = np.mean(relevant_conf) * 100
109
+ if avg_conf > 20: # Only consider results with >20% confidence
110
+ consolidated_results.append((most_common_name, avg_conf))
111
+ detections_queue.clear()
112
+
113
+ # When we have enough consolidated results, print the final result
114
+ if len(consolidated_results) == 3:
115
+ names = [r[0] for r in consolidated_results]
116
+ confidences = [r[1] for r in consolidated_results]
117
+ most_common_name, _ = Counter(names).most_common(1)[0]
118
+ avg_conf = np.mean([c for n, c in consolidated_results if n == most_common_name])
119
+ msg = f" Consolidated result: {most_common_name} with average confidence {avg_conf:.2f}%"
120
+ if msg != last_printed_result:
121
+ print(msg)
122
+ last_printed_result = msg
123
+
124
+ except Exception as e:
125
+ print(f" ONNX Error: {e}")
126
+
127
+ def audio_callback(indata, frames, time_info, status):
128
+ """Callback function for audio input stream."""
129
+ global last_inference_time, is_sleeping_until
130
+
131
+ if status:
132
+ print(f" Status: {status}")
133
+
134
+ # Get mono audio and resample if needed
135
+ audio = indata[:, 0]
136
+ original_sr = sd.query_devices(sd.default.device[0], 'input')['default_samplerate']
137
+ audio = resample_if_needed(audio, original_sr)
138
+
139
+ # Normalize audio
140
+ max_val = np.max(np.abs(audio))
141
+ if max_val > 0:
142
+ audio = audio / max_val
143
+
144
+ # Add audio to buffer
145
+ audio_buffer.extend(audio)
146
+
147
+ # Run inference if buffer is full and not in sleep mode
148
+ if len(audio_buffer) >= BUFFER_SIZE:
149
+ now = time.time()
150
+
151
+ if now < is_sleeping_until:
152
+ return
153
+
154
+ # Throttle inference to once per second
155
+ if now - last_inference_time > 1.0:
156
+ last_inference_time = now
157
+ audio_chunk = np.array(audio_buffer)
158
+ # Run inference in a separate thread
159
+ threading.Thread(target=run_inference, args=(audio_chunk,), daemon=True).start()
160
+
161
+ def main():
162
+ """Main function to start audio streaming and processing."""
163
+ print(" Starting real-time listening (Ctrl+C to stop)...")
164
+
165
+ # List available input devices
166
+ devices = sd.query_devices()
167
+ input_devices = [i for i, d in enumerate(devices) if d['max_input_channels'] > 0]
168
+
169
+ print("\nAvailable input devices:")
170
+ for i in input_devices:
171
+ print(f"{i}: {devices[i]['name']}")
172
+
173
+ # Let user select input device
174
+ device_id = int(input("Select input device ID: "))
175
+ sd.default.device = (device_id, None)
176
+
177
+ # Configure audio settings
178
+ samplerate = int(sd.query_devices(device_id, 'input')['default_samplerate'])
179
+ sd.default.samplerate = samplerate
180
+ sd.default.channels = 1 # Mono audio
181
+
182
+ print("🎧 Listening...")
183
+
184
+ # Start audio stream with callback
185
+ with sd.InputStream(callback=audio_callback, channels=1, samplerate=samplerate):
186
+ try:
187
+ while True:
188
+ time.sleep(0.1) # Keep main thread alive
189
+ except KeyboardInterrupt:
190
+ print("\n Stopping listener...")
191
+
192
+ if __name__ == "__main__":
193
+ main()
194
+
195
+ ```
196
+