| | import threading |
| | from queue import Queue |
| | import sounddevice as sd |
| | import numpy as np |
| | import requests |
| | import base64 |
| | import time |
| | from dataclasses import dataclass, field |
| |
|
| | @dataclass |
| | class AudioStreamingClientArguments: |
| | sample_rate: int = field(default=16000, metadata={"help": "Audio sample rate in Hz. Default is 16000."}) |
| | chunk_size: int = field(default=512, metadata={"help": "The size of audio chunks in samples. Default is 1024."}) |
| | api_url: str = field(default="https://yxfmjcvuzgi123sw.us-east-1.aws.endpoints.huggingface.cloud", metadata={"help": "The URL of the API endpoint."}) |
| | auth_token: str = field(default="your_auth_token", metadata={"help": "Authentication token for the API."}) |
| |
|
| | class AudioStreamingClient: |
| | def __init__(self, args: AudioStreamingClientArguments, handler): |
| | self.args = args |
| | self.handler = handler |
| | self.stop_event = threading.Event() |
| | self.send_queue = Queue() |
| | self.recv_queue = Queue() |
| | self.session_id = None |
| | self.headers = { |
| | "Accept": "application/json", |
| | "Authorization": f"Bearer {self.args.auth_token}", |
| | "Content-Type": "application/json" |
| | } |
| | self.session_state = "idle" |
| |
|
| | def start(self): |
| | print("Starting audio streaming...") |
| | |
| | send_thread = threading.Thread(target=self.send_audio) |
| | play_thread = threading.Thread(target=self.play_audio) |
| |
|
| | with sd.InputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=self.audio_callback, blocksize=self.args.chunk_size): |
| | send_thread.start() |
| | play_thread.start() |
| |
|
| | try: |
| | input("Press Enter to stop streaming...") |
| | except KeyboardInterrupt: |
| | print("\nStreaming interrupted by user.") |
| | finally: |
| | self.stop_event.set() |
| | send_thread.join() |
| | play_thread.join() |
| | print("Audio streaming stopped.") |
| |
|
| | def audio_callback(self, indata, frames, time, status): |
| | self.send_queue.put(indata.copy()) |
| |
|
| | def send_audio(self): |
| | buffer = b'' |
| | while not self.stop_event.is_set(): |
| | if self.session_state != "processing" and not self.send_queue.empty(): |
| | chunk = self.send_queue.get().tobytes() |
| | buffer += chunk |
| | if len(buffer) >= self.args.chunk_size * 2: |
| | self.send_request(buffer) |
| | buffer = b'' |
| | else: |
| | self.send_request() |
| | time.sleep(0.1) |
| |
|
| | def send_request(self, audio_data=None): |
| | payload = {"input_type": "speech", |
| | "inputs": ""} |
| |
|
| | if audio_data is not None: |
| | print("Sending audio data") |
| | payload["inputs"] = base64.b64encode(audio_data).decode('utf-8') |
| |
|
| | if self.session_id: |
| | payload["session_id"] = self.session_id |
| | payload["request_type"] = "continue" |
| | else: |
| | payload["request_type"] = "start" |
| |
|
| | try: |
| | response_data = self.handler(payload) |
| |
|
| | if "session_id" in response_data: |
| | self.session_id = response_data["session_id"] |
| |
|
| | if "status" in response_data and response_data["status"] == "processing": |
| | print("Processing audio data") |
| | self.session_state = "processing" |
| |
|
| | if "output" in response_data and response_data["output"]: |
| | print("Received audio data") |
| | self.session_state = "processing" |
| | audio_bytes = base64.b64decode(response_data["output"]) |
| | audio_np = np.frombuffer(audio_bytes, dtype=np.int16) |
| | |
| | for i in range(0, len(audio_np), self.args.chunk_size): |
| | chunk = audio_np[i:i+self.args.chunk_size] |
| | self.recv_queue.put(chunk) |
| |
|
| | if "status" in response_data and response_data["status"] == "completed": |
| | print("Completed audio processing") |
| | self.session_state = None |
| | self.session_id = None |
| | while not self.recv_queue.empty(): |
| | time.sleep(0.01) |
| | while not self.send_queue.empty(): |
| | _ = self.send_queue.get() |
| |
|
| | except Exception as e: |
| | print(f"Error sending request: {e}") |
| | self.session_state = "idle" |
| |
|
| | def play_audio(self): |
| | def audio_callback(outdata, frames, time, status): |
| | if not self.recv_queue.empty(): |
| | chunk = self.recv_queue.get() |
| | |
| | |
| | chunk_int16 = np.clip(chunk, -32768, 32767).astype(np.int16) |
| | |
| | if len(chunk_int16) < len(outdata): |
| | outdata[:len(chunk_int16), 0] = chunk_int16 |
| | outdata[len(chunk_int16):] = 0 |
| | else: |
| | outdata[:, 0] = chunk_int16[:len(outdata)] |
| | else: |
| | outdata[:] = 0 |
| |
|
| | with sd.OutputStream(samplerate=self.args.sample_rate, channels=1, dtype='int16', callback=audio_callback, blocksize=self.args.chunk_size): |
| | while not self.stop_event.is_set(): |
| | time.sleep(0.01) |
| |
|
| | if __name__ == "__main__": |
| | import argparse |
| |
|
| | parser = argparse.ArgumentParser(description="Audio Streaming Client") |
| | parser.add_argument("--sample_rate", type=int, default=16000, help="Audio sample rate in Hz. Default is 16000.") |
| | parser.add_argument("--chunk_size", type=int, default=1024, help="The size of audio chunks in samples. Default is 1024.") |
| | parser.add_argument("--api_url", type=str, required=True, help="The URL of the API endpoint.") |
| | parser.add_argument("--auth_token", type=str, required=True, help="Authentication token for the API.") |
| |
|
| | args = parser.parse_args() |
| | client_args = AudioStreamingClientArguments(**vars(args)) |
| | client = AudioStreamingClient(client_args) |
| | client.start() |