File size: 7,277 Bytes
9eadafc
 
 
 
 
 
 
 
 
 
 
 
 
1ac339e
337f563
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
---
pipeline_tag: audio-classification
---
This model is an ONNX version of a YAMNet-based classifier trained to recognize environmental and human-made sounds (e.g., speech, music, silence, barking, etc.).

Format: ONNX (.onnx)

Input: float32 mono audio, sampled at 16,000 Hz

Output: Score matrix of shape [num_frames, num_classes]

Classes: See mapping in yamnet_class_map.csv

Example usage:
```python

import sounddevice as sd
import numpy as np
import onnxruntime
import scipy.signal
import csv
import threading
import time
from collections import deque, Counter

# Path to the ONNX model file (relative to the working directory)
MODEL_PATH = "./yamnet.onnx"
# Create an inference session with ONNX runtime using CPU provider
# NOTE: this runs at import time and raises if the model file is missing.
session = onnxruntime.InferenceSession(MODEL_PATH, providers=['CPUExecutionProvider'])
# Get the name of the model's input node (used as the feed key in session.run)
input_name = session.get_inputs()[0].name

def load_class_map(csv_path="yamnet_class_map.csv"):
    """Read the YAMNet class-map CSV and return ``{class_id: display_name}``.

    The file is expected to have a header row followed by rows whose first
    column is the integer class index and whose third column is the
    human-readable class name; any extra columns are ignored.
    """
    with open(csv_path, newline='') as fh:
        rows = csv.reader(fh)
        next(rows)  # discard the header row
        return {int(row[0]): row[2] for row in rows}

# Load the class mapping from CSV (reads yamnet_class_map.csv at import time)
class_map = load_class_map()

# Audio buffer settings (1.5 seconds at 16kHz sample rate = 24000 samples)
BUFFER_SIZE = int(16000 * 1.5) 
# Circular buffer to store audio samples; old samples fall off the left
# once the buffer is full, giving a sliding 1.5 s analysis window.
audio_buffer = deque(maxlen=BUFFER_SIZE)

# Queues for storing and consolidating detection results
detections_queue = deque(maxlen=3)  # Stores recent (class_name, confidence) detections
consolidated_results = deque(maxlen=3)  # Stores consolidated (name, avg_conf_pct) results
last_printed_result = None  # Last result message printed to console (dedup guard)
last_inference_time = 0  # Timestamp (time.time()) of last inference, for throttling

# Configuration for sleep mode (temporarily pausing inference)
# NOTE(review): these names are lowercase but YAMNet class names are
# capitalized ("Speech", "Music") — verify the comparison in run_inference.
sleep_triggers = {"music", "speech", "silence"}  # Classes that trigger sleep mode
sleep_duration = 3.0  # How long (seconds) to sleep after detecting trigger classes
same_class_count = 0  # Counter for consecutive same-class detections
last_detected_class = None  # Last class detected
is_sleeping_until = 0  # Timestamp (time.time()) until which we're in sleep mode

def resample_if_needed(audio, original_sr, target_sr=16000):
    """Return *audio* resampled from original_sr to target_sr.

    The input is returned unchanged when the two rates already match.
    """
    if original_sr == target_sr:
        return audio
    # Polyphase resampling: up by target_sr, down by the (integer) source rate.
    return scipy.signal.resample_poly(audio, target_sr, int(original_sr))

def run_inference(audio_chunk):
    """Run the ONNX model on one audio chunk and update detection state.

    Mutates the module-level queues and counters, possibly entering sleep
    mode, and prints a consolidated result message when three consolidated
    windows are available. Errors from the ONNX runtime are caught and
    reported rather than propagated (this runs on a daemon thread).
    """
    global detections_queue, consolidated_results, last_printed_result
    global is_sleeping_until, same_class_count, last_detected_class

    try:
        # Run the ONNX model; input must be float32 mono audio at 16 kHz.
        outputs = session.run(None, {input_name: audio_chunk.astype(np.float32)})
        scores = outputs[0]
        mean_scores = np.mean(scores, axis=0)  # Average scores across frames

        # Get the class with highest score
        class_id = int(np.argmax(mean_scores))
        confidence = mean_scores[class_id]
        class_name = class_map.get(class_id, f"Class {class_id}")

        # Track consecutive detections of the same class
        if class_name == last_detected_class:
            same_class_count += 1
        else:
            same_class_count = 1
            last_detected_class = class_name

        # Enter sleep mode if we detect trigger classes consecutively.
        # BUG FIX: compare case-insensitively — YAMNet class names are
        # capitalized ("Speech", "Music", ...) while sleep_triggers holds
        # lowercase names, so the original membership test never matched.
        if same_class_count >= 3 and class_name.lower() in sleep_triggers:
            is_sleeping_until = time.time() + sleep_duration
            print(f" Pausing inferences for {sleep_duration} seconds (detected: {class_name})")
            same_class_count = 0

        # Add detection to queue
        detections_queue.append((class_name, confidence))

        # When queue is full, consolidate results.
        # BUG FIX: consolidate only once the deque is actually full (maxlen
        # entries); the original `== 1` consolidated after every single
        # detection, defeating the majority vote.
        if len(detections_queue) == detections_queue.maxlen:
            most_common_name, _ = Counter([d[0] for d in detections_queue]).most_common(1)[0]
            relevant_conf = [conf for n, conf in detections_queue if n == most_common_name]
            avg_conf = np.mean(relevant_conf) * 100
            if avg_conf > 20:  # Only consider results with >20% confidence
                consolidated_results.append((most_common_name, avg_conf))
            detections_queue.clear()

        # When we have enough consolidated results, print the final result
        if len(consolidated_results) == 3:
            names = [r[0] for r in consolidated_results]
            confidences = [r[1] for r in consolidated_results]
            most_common_name, _ = Counter(names).most_common(1)[0]
            avg_conf = np.mean([c for n, c in consolidated_results if n == most_common_name])
            msg = f" Consolidated result: {most_common_name} with average confidence {avg_conf:.2f}%"
            # Avoid spamming the console with the identical message.
            if msg != last_printed_result:
                print(msg)
                last_printed_result = msg

    except Exception as e:
        print(f" ONNX Error: {e}")

def audio_callback(indata, frames, time_info, status):
    """Callback function for audio input stream.

    Invoked by sounddevice on its audio thread for every captured block.
    Accumulates (resampled, normalized) mono samples into the sliding
    buffer and, at most once per second, launches an inference thread.
    """
    global last_inference_time, is_sleeping_until

    if status:
        # Report over/underruns or other stream problems.
        print(f" Status: {status}")

    # Get mono audio (first channel) and resample if needed.
    audio = indata[:, 0]
    # NOTE(review): querying the device on every callback is relatively
    # expensive for a real-time audio path — consider caching this rate.
    original_sr = sd.query_devices(sd.default.device[0], 'input')['default_samplerate']
    audio = resample_if_needed(audio, original_sr)

    # Normalize audio to peak amplitude 1 (per callback block).
    max_val = np.max(np.abs(audio))
    if max_val > 0:
        audio = audio / max_val

    # Add audio to buffer (deque drops oldest samples once full).
    audio_buffer.extend(audio)

    # Run inference if buffer is full and not in sleep mode
    if len(audio_buffer) >= BUFFER_SIZE:
        now = time.time()

        # Still inside a sleep window triggered by run_inference — skip.
        if now < is_sleeping_until:
            return

        # Throttle inference to once per second
        if now - last_inference_time > 1.0:
            last_inference_time = now
            # Snapshot the buffer; the deque keeps filling concurrently.
            audio_chunk = np.array(audio_buffer)
            # Run inference in a separate thread so the audio callback returns fast.
            threading.Thread(target=run_inference, args=(audio_chunk,), daemon=True).start()

def main():
    """Prompt for an input device, then stream microphone audio into the classifier."""
    print(" Starting real-time listening (Ctrl+C to stop)...")

    # Enumerate devices and show only those that can capture audio.
    devices = sd.query_devices()
    input_devices = [idx for idx, dev in enumerate(devices) if dev['max_input_channels'] > 0]

    print("\nAvailable input devices:")
    for idx in input_devices:
        print(f"{idx}: {devices[idx]['name']}")

    # Let the user pick a capture device; leave the output device untouched.
    device_id = int(input("Select input device ID: "))
    sd.default.device = (device_id, None)

    # Use the device's native rate; audio_callback resamples to 16 kHz.
    samplerate = int(sd.query_devices(device_id, 'input')['default_samplerate'])
    sd.default.samplerate = samplerate
    sd.default.channels = 1  # Mono audio

    print("🎧 Listening...")

    # Stream until Ctrl+C; all work happens inside audio_callback.
    with sd.InputStream(callback=audio_callback, channels=1, samplerate=samplerate):
        try:
            while True:
                time.sleep(0.1)  # Keep main thread alive
        except KeyboardInterrupt:
            print("\n Stopping listener...")

```