Spaces:
Runtime error
Runtime error
| import asyncio | |
| import base64 | |
| import os | |
| from threading import Event, Thread | |
| import gradio as gr | |
| import numpy as np | |
| import openai | |
| from dotenv import load_dotenv | |
| from gradio_webrtc import ( | |
| AdditionalOutputs, | |
| StreamHandler, | |
| WebRTC, | |
| get_twilio_turn_credentials, | |
| ) | |
| from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent | |
| from pydub import AudioSegment | |
# NOTE(review): TWILIO_ACCOUNT_SID / TWILIO_AUTH_TOKEN are presumably expected
# in the process environment by get_twilio_turn_credentials() -- confirm
# against the gradio_webrtc documentation. Supply them through the
# environment or the .env file loaded below rather than assigning them here.

# Load variables from a local .env file (if present) into the environment.
load_dotenv()

# Sample rate in Hz used for the handler's input and output audio and as the
# target frame rate when encoding outgoing audio.
SAMPLE_RATE = 24000
def encode_audio(sample_rate, data):
    """Return ``data`` as base64 text holding 16-bit mono PCM at SAMPLE_RATE.

    Args:
        sample_rate: Frame rate, in Hz, of the samples in ``data``.
        data: NumPy array of samples. Its raw bytes are interpreted as a
            single channel whose sample width is the dtype's item size.

    Returns:
        The converted raw PCM bytes, base64-encoded and decoded to ``str``.
    """
    source_segment = AudioSegment(
        data.tobytes(),
        frame_rate=sample_rate,
        sample_width=data.dtype.itemsize,
        channels=1,
    )
    # Convert in the same order as a chained call would: rate, channels, width.
    at_target_rate = source_segment.set_frame_rate(SAMPLE_RATE)
    as_mono = at_target_rate.set_channels(1)
    as_16_bit = as_mono.set_sample_width(2)
    encoded_bytes = base64.b64encode(as_16_bit.raw_data)
    return encoded_bytes.decode("utf-8")
class OpenAIHandler(StreamHandler):
    """Relay audio frames between a WebRTC stream and the OpenAI realtime API.

    Incoming frames are encoded with ``encode_audio`` and appended to the
    realtime connection's input audio buffer. Events read back from the
    connection are turned into audio chunks or ``AdditionalOutputs`` by
    ``emit``. The realtime connection is opened in a separate thread that
    stays alive until ``quit`` is set.
    """

    def __init__(
        self,
        expected_layout="mono",
        output_sample_rate=SAMPLE_RATE,
        output_frame_size=480,
    ) -> None:
        """Create a handler with no open connection.

        Args:
            expected_layout: Forwarded to ``StreamHandler``.
            output_sample_rate: Forwarded to ``StreamHandler``; also used as
                the sample rate of emitted audio.
            output_frame_size: Forwarded to ``StreamHandler``.
        """
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=SAMPLE_RATE,
        )
        self.client = None
        self.connection = None
        self.all_output_data = None
        self.args_set = Event()  # set once set_args() has been called
        self.quit = Event()  # set to let the connection thread exit
        self.connected = Event()  # set when the connection attempt finished
        self.thread = None
        self._generator = None
        self.current_session = None
        # Exception raised by the connection thread, if the attempt failed.
        self._connection_error = None

    def copy(self):
        """Return a fresh handler configured with the same audio settings."""
        return OpenAIHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _initialize_connection(self, api_key: str):
        """Connect to the realtime API and hold the connection open.

        Meant to run in its own thread: after connecting it blocks until
        ``quit`` is set. If connecting fails, the error is recorded in
        ``_connection_error`` and ``connected`` is still set so that a
        ``receive`` call waiting on it is released instead of hanging.

        Args:
            api_key: OpenAI API key used to create the client.
        """
        try:
            self.client = openai.Client(api_key=api_key)
            with self.client.beta.realtime.connect(
                model="gpt-4o-mini-realtime-preview-2024-12-17"
            ) as conn:
                conn.session.update(session={"turn_detection": {"type": "server_vad"}})
                self.connection = conn
                self.connected.set()
                self.quit.wait()
        except Exception as exc:
            import traceback

            self._connection_error = exc
            print(f"Error connecting to realtime API: {str(exc)}")
            traceback.print_exc()
            # Wake up receive(), which would otherwise wait forever.
            self.connected.set()

    async def fetch_args(
        self,
    ):
        """Send a "tick" message over the data channel, if one exists.

        NOTE(review): presumably this prompts the client to send the latest
        input values, which arrive through ``set_args`` -- confirm against
        gradio_webrtc.
        """
        if self.channel:
            self.channel.send("tick")

    def set_args(self, args):
        """Store the arguments via the base class and signal their arrival."""
        super().set_args(args)
        self.args_set.set()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one incoming audio frame to the realtime API.

        The first frame triggers the connection: arguments are requested, the
        connection thread is started with the last argument as the API key,
        and this call waits until the attempt has finished. Errors while
        sending are printed, not raised.

        Args:
            frame: ``(sample_rate, samples)`` pair for the incoming audio.
        """
        if not self.channel:
            return
        if self._connection_error is not None:
            # A previous attempt failed; drop audio until reset_state()
            # instead of starting a new connection attempt for every frame.
            return
        if not self.connection:
            asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
            self.args_set.wait()
            self.thread = Thread(
                target=self._initialize_connection, args=(self.latest_args[-1],)
            )
            self.thread.start()
            self.connected.wait()
        try:
            # Explicit check instead of ``assert`` so it survives ``python -O``.
            if not self.connection:
                raise RuntimeError("Connection not initialized")
            sample_rate, array = frame
            array = array.squeeze()
            audio_message = encode_audio(sample_rate, array)
            self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            # Best effort: report the failure and keep the stream running.
            print(f"Error in receive: {str(e)}")
            import traceback

            traceback.print_exc()

    def generator(self):
        """Yield outputs derived from realtime events, indefinitely.

        Yields ``None`` while there is no connection. For each event read
        from the connection, a finished audio transcript is wrapped in
        ``AdditionalOutputs`` and an audio delta is yielded as
        ``(output_sample_rate, samples)`` where ``samples`` is the decoded
        int16 data reshaped to a single row.
        """
        while True:
            if not self.connection:
                yield None
                continue
            for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    yield AdditionalOutputs(event)
                elif event.type == "response.audio.delta":
                    yield (
                        self.output_sample_rate,
                        np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1),
                    )

    def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next item from ``generator``, or ``None`` if unavailable.

        Returns ``None`` when there is no connection. If the generator has
        finished, a new one is created and ``None`` is returned for this call.
        """
        if not self.connection:
            return None
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            self._generator = self.generator()
            return None

    def reset_state(self):
        """Reset connection state for a new recording session."""
        self.connection = None
        self.args_set.clear()
        self.quit.clear()
        self.connected.clear()
        self.thread = None
        self._generator = None
        self.current_session = None
        self._connection_error = None

    def shutdown(self) -> None:
        """Close the connection, stop the connection thread and reset state."""
        if self.connection:
            self.connection.close()
        # Signal quit unconditionally so a thread that is still connecting
        # does not end up waiting on ``quit`` with nobody left to set it.
        self.quit.set()
        if self.thread:
            self.thread.join(timeout=5)
        self.reset_state()
def update_chatbot(chatbot: list[dict], response: ResponseAudioTranscriptDoneEvent):
    """Append the assistant's transcript to the chat history.

    Args:
        chatbot: Current list of chat messages; modified in place.
        response: Event whose ``transcript`` becomes the message content.

    Returns:
        The same ``chatbot`` list, now ending with the new assistant message.
    """
    assistant_message = {"role": "assistant", "content": response.transcript}
    chatbot.append(assistant_message)
    return chatbot
def _toggle_rows():
    """Return updates that hide the API-key row and show the conversation row."""
    hide_key_row = gr.update(visible=False)
    show_conversation_row = gr.update(visible=True)
    return hide_key_row, show_conversation_row


with gr.Blocks() as demo:
    # Row with the API key field, pre-filled from OPENAI_API_KEY when set.
    with gr.Row(visible=True) as api_key_row:
        api_key = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your OpenAI API Key",
            value=os.getenv("OPENAI_API_KEY", ""),
            type="password",
        )

    # Row with the audio widget (narrow column) and the transcript (wide column).
    with gr.Row(visible=True) as row:
        with gr.Column(scale=1):
            turn_credentials = get_twilio_turn_credentials()
            webrtc = WebRTC(
                label="Conversation",
                modality="audio",
                mode="send-receive",
                rtc_configuration=turn_credentials,
                icon="openai-logo.svg",
            )
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(label="Conversation", value=[], type="messages")

    # The API key is the last input, which the handler reads when connecting.
    webrtc.stream(
        OpenAIHandler(),
        inputs=[webrtc, api_key],
        outputs=[webrtc],
        time_limit=90,
        concurrency_limit=2,
    )

    # Finished transcripts are appended to the chatbot history.
    # NOTE(review): show_progress="hidden" was commented out in the original
    # call; confirm whether it should be passed here.
    webrtc.on_additional_outputs(
        update_chatbot,
        inputs=[chatbot],
        outputs=[chatbot],
        queue=True,
    )

    # Submitting the key hides its row and (re)shows the conversation row.
    api_key.submit(
        _toggle_rows,
        None,
        [api_key_row, row],
    )

if __name__ == "__main__":
    demo.launch()