EduTechTeam committed on
Commit
8b31215
·
verified ·
1 Parent(s): 9e87648

Create app.py

Browse files
Files changed (1) hide show
  1. app.py +207 -0
app.py ADDED
@@ -0,0 +1,207 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import base64
3
+ import os
4
+ from threading import Event, Thread
5
+ from google.colab import userdata
6
+
7
+ import gradio as gr
8
+ import numpy as np
9
+ import openai
10
+ from dotenv import load_dotenv
11
+ from gradio_webrtc import (
12
+ AdditionalOutputs,
13
+ StreamHandler,
14
+ WebRTC,
15
+ get_twilio_turn_credentials,
16
+ )
17
+ from openai.types.beta.realtime import ResponseAudioTranscriptDoneEvent
18
+ from pydub import AudioSegment
19
+
20
# Pull Twilio TURN credentials from Colab's secret store and expose them as
# environment variables so get_twilio_turn_credentials() can find them.
# NOTE(review): userdata.get() returns None when the secret is missing, which
# would make these assignments raise TypeError — confirm secrets are configured.
os.environ["TWILIO_ACCOUNT_SID"] = userdata.get('TWILIO_ACCOUNT_SID')
os.environ["TWILIO_AUTH_TOKEN"] = userdata.get('TWILIO_AUTH_TOKEN')

# Load any additional configuration (e.g. OPENAI_API_KEY) from a .env file.
load_dotenv()

# Sample rate (Hz) used for both input and output audio with the Realtime API.
SAMPLE_RATE = 24000
26
+
27
+
28
def encode_audio(sample_rate, data):
    """Convert a mono numpy audio buffer to base64-encoded PCM.

    The buffer is resampled to SAMPLE_RATE, forced to one channel and
    16-bit samples, then the raw PCM bytes are base64-encoded for the
    Realtime API's input audio buffer.

    Args:
        sample_rate: Sampling rate (Hz) of ``data``.
        data: 1-D numpy array of audio samples.

    Returns:
        Base64 string of 16-bit mono PCM at SAMPLE_RATE.
    """
    source = AudioSegment(
        data.tobytes(),
        frame_rate=sample_rate,
        sample_width=data.dtype.itemsize,
        channels=1,
    )
    # Normalize to the format the Realtime API expects: 24 kHz, mono, s16le.
    normalized = source.set_frame_rate(SAMPLE_RATE).set_channels(1).set_sample_width(2)
    return base64.b64encode(normalized.raw_data).decode("utf-8")
39
+
40
+
41
class OpenAIHandler(StreamHandler):
    """WebRTC stream handler that bridges browser audio to the OpenAI Realtime API.

    Incoming microphone frames (receive) are base64-encoded and appended to the
    realtime input audio buffer; outgoing model audio and transcripts (emit)
    are pulled from the realtime event stream. The realtime connection is kept
    open on a dedicated background thread and coordinated with threading.Events.
    """

    def __init__(
        self,
        expected_layout="mono",
        output_sample_rate=SAMPLE_RATE,
        output_frame_size=480,
    ) -> None:
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=SAMPLE_RATE,
        )
        # Live realtime connection object; None until the background thread
        # has connected (see _initialize_connection).
        self.connection = None
        # NOTE(review): assigned here but never used elsewhere in this file.
        self.all_output_data = None
        # Set once set_args() has delivered the latest UI inputs (API key).
        self.args_set = Event()
        # Signals the connection thread to exit its keep-alive wait.
        self.quit = Event()
        # Set by the connection thread once the realtime session is ready.
        self.connected = Event()
        # Background thread running _initialize_connection.
        self.thread = None
        # Lazily-created generator over realtime events (see emit/generator).
        self._generator = None

    def copy(self):
        """Return a fresh handler instance for a new stream (required by StreamHandler)."""
        return OpenAIHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
        )

    def _initialize_connection(self, api_key: str):
        """Connect to realtime API. Run forever in separate thread to keep connection open.

        Blocks on self.quit so the `with` block (and thus the connection)
        stays alive until shutdown() sets it.
        """
        self.client = openai.Client(api_key=api_key)
        with self.client.beta.realtime.connect(
            model="gpt-4o-mini-realtime-preview-2024-12-17"
        ) as conn:
            # Let the server decide turn boundaries (server-side VAD).
            conn.session.update(session={"turn_detection": {"type": "server_vad"}})
            self.connection = conn
            self.connected.set()
            # Hold the connection open until shutdown() requests exit.
            self.quit.wait()

    async def fetch_args(
        self,
    ):
        """Ping the client data channel so the frontend re-sends the latest args."""
        if self.channel:
            self.channel.send("tick")

    def set_args(self, args):
        """Record the latest UI arguments and unblock receive()'s wait."""
        super().set_args(args)
        self.args_set.set()

    def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one microphone frame to the realtime input audio buffer.

        On the first frame, lazily fetches the API key from the UI and spins
        up the connection thread, blocking until the session is established.
        """
        if not self.channel:
            return
        if not self.connection:
            # Ask the frontend for current args (API key), then wait for them.
            asyncio.run_coroutine_threadsafe(self.fetch_args(), self.loop)
            self.args_set.wait()
            # latest_args[-1] is the api_key textbox value (last stream input).
            self.thread = Thread(
                target=self._initialize_connection, args=(self.latest_args[-1],)
            )
            self.thread.start()
            self.connected.wait()
        try:
            assert self.connection, "Connection not initialized"
            sample_rate, array = frame
            array = array.squeeze()
            audio_message = encode_audio(sample_rate, array)
            self.connection.input_audio_buffer.append(audio=audio_message)
        except Exception as e:
            # print traceback
            print(f"Error in receive: {str(e)}")
            import traceback

            traceback.print_exc()

    def generator(self):
        """Yield output items from the realtime event stream.

        Yields AdditionalOutputs(event) for finished transcripts (consumed by
        on_additional_outputs) and (sample_rate, int16 array) tuples for audio
        deltas; yields None while the connection is not yet established.
        """
        while True:
            if not self.connection:
                yield None
                continue
            for event in self.connection:
                if event.type == "response.audio_transcript.done":
                    yield AdditionalOutputs(event)
                if event.type == "response.audio.delta":
                    yield (
                        self.output_sample_rate,
                        # Decode base64 PCM into a (1, n) int16 frame.
                        np.frombuffer(
                            base64.b64decode(event.delta), dtype=np.int16
                        ).reshape(1, -1),
                    )

    def emit(self) -> tuple[int, np.ndarray] | None:
        """Return the next output item for the WebRTC stream, or None if idle."""
        if not self.connection:
            return None
        if not self._generator:
            self._generator = self.generator()
        try:
            return next(self._generator)
        except StopIteration:
            # Event stream ended (e.g. connection dropped): start a fresh
            # generator for the next call.
            self._generator = self.generator()
            return None

    def reset_state(self):
        """Reset connection state for new recording session"""
        self.connection = None
        self.args_set.clear()
        self.quit.clear()
        self.connected.clear()
        self.thread = None
        self._generator = None
        # NOTE(review): current_session is never set anywhere else in this
        # file; this assignment creates the attribute on first reset.
        self.current_session = None

    def shutdown(self) -> None:
        """Close the realtime connection, stop the worker thread, and reset state."""
        if self.connection:
            self.connection.close()
        # Unblock _initialize_connection's quit.wait() so the thread exits.
        self.quit.set()
        if self.thread:
            self.thread.join(timeout=5)
        self.reset_state()
158
+
159
+
160
def update_chatbot(chatbot: list[dict], response: "ResponseAudioTranscriptDoneEvent"):
    """Append the assistant's finished transcript to the chat history.

    Args:
        chatbot: Current chat history in gradio "messages" format
            (list of {"role": ..., "content": ...} dicts).
        response: Realtime API event carrying the completed transcript.

    Returns:
        A new history list with the assistant message appended; the input
        list is left unmodified (the original mutated it in place).
    """
    return [*chatbot, {"role": "assistant", "content": response.transcript}]
163
+
164
+
165
# Build the Gradio UI: an API-key entry row that, on submit, reveals the
# conversation row (WebRTC audio widget + transcript chatbot).
with gr.Blocks() as demo:

    with gr.Row(visible=True) as api_key_row:
        api_key = gr.Textbox(
            label="OpenAI API Key",
            placeholder="Enter your OpenAI API Key",
            # Pre-fill from the environment (.env via load_dotenv) when available.
            value=os.getenv("OPENAI_API_KEY", ""),
            type="password",
        )
    with gr.Row(visible=True) as row:
        with gr.Column(scale=1):
            webrtc = WebRTC(
                label="Conversation",
                modality="audio",
                mode="send-receive",
                # TURN credentials read from the Twilio env vars set above.
                rtc_configuration=get_twilio_turn_credentials(),
                icon="openai-logo.svg",
            )
        with gr.Column(scale=5):
            chatbot = gr.Chatbot(label="Conversation", value=[], type="messages")
        # Stream audio through OpenAIHandler; api_key is the last input, which
        # the handler reads as latest_args[-1].
        webrtc.stream(
            OpenAIHandler(),
            inputs=[webrtc, api_key],
            outputs=[webrtc],
            time_limit=90,
            concurrency_limit=2,
        )
        # Transcript events yielded as AdditionalOutputs land here and are
        # appended to the chatbot history.
        webrtc.on_additional_outputs(
            update_chatbot,
            inputs=[chatbot],
            outputs=[chatbot],
            show_progress="hidden",
            queue=True,
        )
        # Submitting the key hides the key row and shows the conversation row.
        # NOTE(review): `row` is created with visible=True, so gr.update(visible=True)
        # here is a no-op — presumably it was meant to start hidden; confirm intent.
        api_key.submit(
            lambda: (gr.update(visible=False), gr.update(visible=True)),
            None,
            [api_key_row, row],
        )


if __name__ == "__main__":
    demo.launch()