nihalaninihal committed on
Commit
0f01eb3
·
verified ·
1 Parent(s): dc64990

Upload 5 files

Browse files
Files changed (5) hide show
  1. app.py +174 -0
  2. header.html +12 -0
  3. prettierrc.txt +6 -0
  4. requirements.txt +4 -0
  5. style.css +4 -0
app.py ADDED
@@ -0,0 +1,174 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ import asyncio
2
+ import pathlib
3
+ from typing import AsyncGenerator, Literal
4
+
5
+ from google import genai
6
+ from google.genai.types import (
7
+ Content,
8
+ LiveConnectConfig,
9
+ Part,
10
+ PrebuiltVoiceConfig,
11
+ SpeechConfig,
12
+ VoiceConfig,
13
+ )
14
+ import gradio as gr
15
+ from gradio_webrtc import AsyncStreamHandler, WebRTC, async_aggregate_bytes_to_16bit
16
+ import numpy as np
17
+
18
# Directory that contains this script; the CSS/HTML assets are loaded relative to it.
current_dir = pathlib.Path(__file__).parents[0]
19
+
20
+
21
class GeminiHandler(AsyncStreamHandler):
    """Stream handler that relays audio between a WebRTC peer and Gemini.

    Incoming frames are serialized to bytes and queued in ``input_queue``;
    a background task forwards them to a Gemini live session and places the
    audio returned by the server in ``output_queue``, from which ``emit``
    serves frames back to the peer.
    """

    def __init__(
        self,
        expected_layout: Literal["mono"] = "mono",
        output_sample_rate: int = 24000,
        output_frame_size: int = 480,
        input_sample_rate: int = 16000,
    ) -> None:
        """Initialize the base handler and the per-connection queues."""
        super().__init__(
            expected_layout,
            output_sample_rate,
            output_frame_size,
            input_sample_rate=input_sample_rate,
        )
        self.input_queue: asyncio.Queue = asyncio.Queue()
        self.output_queue: asyncio.Queue = asyncio.Queue()
        self.quit: asyncio.Event = asyncio.Event()
        # Strong reference to the background task started in ``emit``; the
        # event loop itself only keeps weak references to tasks.
        self._generator_task: asyncio.Task | None = None

    def copy(self) -> "GeminiHandler":
        """Required implementation of the copy method for AsyncStreamHandler.

        Returns:
            A fresh handler configured with the same layout and sample rates
            as this one (including the input sample rate).
        """
        return GeminiHandler(
            expected_layout=self.expected_layout,
            output_sample_rate=self.output_sample_rate,
            output_frame_size=self.output_frame_size,
            input_sample_rate=self.input_sample_rate,
        )

    async def stream(self) -> AsyncGenerator[bytes, None]:
        """Yield queued input audio until ``quit`` is set.

        Used by ``connect`` as the outgoing audio source for the session.
        """
        while not self.quit.is_set():
            try:
                # Poll with a short timeout so that a shutdown is noticed even
                # when no more audio arrives; a bare ``get()`` would block
                # forever once the peer stops sending.
                audio = await asyncio.wait_for(self.input_queue.get(), 0.1)
            except asyncio.TimeoutError:
                continue
            yield audio

    async def connect(
        self,
        project_id: str,
        location: str,
        voice_name: str | None = None,
        system_instruction: str | None = None,
    ) -> AsyncGenerator[bytes, None]:
        """Connect to the Gemini server and start the stream.

        Args:
            project_id: Google Cloud project passed to the Vertex AI client.
            location: Vertex AI location passed to the client.
            voice_name: Name of the prebuilt voice to request.
            system_instruction: Optional system prompt; omitted from the
                session config when ``None`` or empty.

        Yields:
            The ``data`` payload of each server response that carries one.
        """
        client = genai.Client(vertexai=True, project=project_id, location=location)
        # Only attach a system instruction when the user actually supplied
        # one, instead of sending a Content wrapping ``None`` or "".
        instruction = (
            Content(parts=[Part.from_text(text=system_instruction)])
            if system_instruction
            else None
        )
        config = LiveConnectConfig(
            response_modalities=["AUDIO"],
            speech_config=SpeechConfig(
                voice_config=VoiceConfig(
                    prebuilt_voice_config=PrebuiltVoiceConfig(
                        voice_name=voice_name,
                    )
                )
            ),
            system_instruction=instruction,
        )
        async with client.aio.live.connect(
            model="gemini-2.0-flash-live-preview-04-09", config=config
        ) as session:
            async for audio in session.start_stream(
                stream=self.stream(), mime_type="audio/pcm"
            ):
                if audio.data:
                    yield audio.data

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Receive audio from the user and put it in the input stream.

        Args:
            frame: ``(sample_rate, samples)`` pair; only the samples are used.
        """
        _, array = frame
        array = array.squeeze()
        audio_message = array.tobytes()
        self.input_queue.put_nowait(audio_message)

    async def generator(self) -> None:
        """Place audio from the server into the output queue.

        The first element of ``latest_args`` is skipped (presumably the
        WebRTC component value -- confirm against the library); the rest are
        forwarded positionally to ``connect``.
        """
        async for audio_response in async_aggregate_bytes_to_16bit(
            self.connect(*self.latest_args[1:])
        ):
            self.output_queue.put_nowait(audio_response)

    async def emit(self) -> tuple[int, np.ndarray]:
        """Required implementation of the emit method for AsyncStreamHandler.

        On the first call, waits for the UI arguments and starts the
        background ``generator`` task; every call then waits for the next
        chunk in the output queue.

        Returns:
            ``(output_sample_rate, samples)`` for the next chunk of audio.
        """
        if not self.args_set.is_set():
            await self.wait_for_args()
            # Keep a reference so the task cannot be garbage-collected while
            # it is still running.
            self._generator_task = asyncio.create_task(self.generator())

        array = await self.output_queue.get()
        return (self.output_sample_rate, array)

    def shutdown(self) -> None:
        """Stop the stream method on shutdown."""
        self.quit.set()
111
+
112
+
113
# Static assets shipped next to this script. Decode explicitly as UTF-8 so the
# result does not depend on the platform's default locale encoding.
css = (current_dir / "style.css").read_text(encoding="utf-8")
header = (current_dir / "header.html").read_text(encoding="utf-8")

with gr.Blocks(css=css) as demo:
    gr.HTML(header)
    # Step 1: configuration form, visible until the user presses "Submit".
    with gr.Group(visible=True, elem_id="api-form") as api_key_row:
        with gr.Row():
            _project_id = gr.Textbox(
                label="Project ID",
                placeholder="Enter your Google Cloud Project ID",
            )
            _location = gr.Dropdown(
                label="Location",
                choices=[
                    "us-central1",
                ],
                value="us-central1",
                info="You can find additional locations [here](https://cloud.google.com/vertex-ai/generative-ai/docs/learn/locations#united-states)",
            )
            _voice_name = gr.Dropdown(
                label="Voice",
                choices=[
                    "Puck",
                    "Charon",
                    "Kore",
                    "Fenrir",
                    "Aoede",
                ],
                value="Puck",
            )
            _system_instruction = gr.Textbox(
                label="System Instruction",
                placeholder="Talk like a pirate.",
            )
        with gr.Row():
            submit = gr.Button(value="Submit")
    # Step 2: the audio conversation widget, hidden until the form is submitted.
    with gr.Row(visible=False) as row:
        webrtc = WebRTC(
            label="Conversation",
            modality="audio",
            mode="send-receive",
            # See for changes needed to deploy behind a firewall
            # https://fastrtc.org/deployment/
            rtc_configuration=None,
        )

    # The form values are passed to the handler along with the audio stream;
    # the handler forwards everything after the first input to its connect().
    webrtc.stream(
        GeminiHandler(),
        inputs=[webrtc, _project_id, _location, _voice_name, _system_instruction],
        outputs=[webrtc],
        time_limit=90,
        concurrency_limit=2,
    )
    # Swap the form for the conversation widget once the user submits.
    submit.click(
        lambda: (gr.update(visible=False), gr.update(visible=True)),
        None,
        [api_key_row, row],
    )


if __name__ == "__main__":
    demo.launch()
header.html ADDED
@@ -0,0 +1,12 @@
 
 
 
 
 
 
 
 
 
 
 
 
 
1
+ <div style="text-align: center">
2
+ <h1>Gemini 2.0 Multimodal Live API Demo</h1>
3
+ <p>Speak with Gemini using real-time audio streaming</p>
4
+ <p>
5
+ You will need to enable Vertex AI
6
+ <a href="https://console.cloud.google.com/flows/enableapi?apiid=aiplatform.googleapis.com">here</a>
7
+ </p>
8
+ <p>
9
+ Also make sure you have enabled default credentials
10
+ <a href="https://cloud.google.com/docs/authentication/provide-credentials-adc#how-to">here</a>
11
+ </p>
12
+ </div>
prettierrc.txt ADDED
@@ -0,0 +1,6 @@
 
 
 
 
 
 
 
1
+ {
2
+ "tabWidth": 2,
3
+ "useTabs": false,
4
+ "printWidth": 120,
5
+ "bracketSameLine": true
6
+ }
requirements.txt ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ gradio>=5.12.0,<6.0
2
+ fastrtc
3
+ librosa
4
+ google-genai==1.10.0
style.css ADDED
@@ -0,0 +1,4 @@
 
 
 
 
 
1
+ #api-form {
2
+ width: 80%;
3
+ margin: auto;
4
+ }