import asyncio
import base64
import io
import os
import time

import cv2
import gradio as gr
import numpy as np
from fastrtc import (
    AsyncAudioVideoStreamHandler,
    WebRTC,
    WebRTCError,
    get_cloudflare_turn_credentials,
    wait_for_item,
)
from google import genai
from PIL import Image

# Gemini Live model used for every session.
MODEL_NAME = "gemini-2.0-flash-exp"

# Minimum number of seconds between two video frames forwarded to Gemini.
# Without a throttle, every incoming frame is JPEG-encoded and uploaded at the
# camera's full frame rate.
FRAME_INTERVAL_S = 1.0


# --- Encoder Helpers ---
def encode_audio(data: np.ndarray) -> dict:
    """Wrap raw PCM samples as a base64 ``audio/pcm`` blob for the Live API.

    Args:
        data: Audio samples; their raw bytes are sent as-is (the handler feeds
            mono frames captured at ``input_sample_rate``).

    Returns:
        A ``{"mime_type", "data"}`` dict with base64-encoded bytes.
    """
    return {"mime_type": "audio/pcm", "data": base64.b64encode(data.tobytes()).decode("UTF-8")}


def encode_image(data: np.ndarray) -> dict:
    """Encode a video frame as a base64 JPEG blob, downscaled to fit 1024x1024.

    Args:
        data: Frame as a uint8 array. 3-channel frames are assumed to be in
            OpenCV BGR order and 4-channel frames in BGRA order.
            NOTE(review): confirm the channel order fastrtc delivers.

    Returns:
        A ``{"mime_type": "image/jpeg", "data": ...}`` dict.
    """
    if data.ndim == 3 and data.shape[2] == 3:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
    elif data.ndim == 3 and data.shape[2] == 4:
        # JPEG has no alpha channel; PIL refuses to save RGBA as JPEG.
        data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGB)
    with io.BytesIO() as output_bytes:
        pil_image = Image.fromarray(data)
        if pil_image.mode not in ("RGB", "L"):
            # Any other mode (e.g. LA, P) cannot be written as JPEG directly.
            pil_image = pil_image.convert("RGB")
        pil_image.thumbnail((1024, 1024))
        pil_image.save(output_bytes, "JPEG")
        return {"mime_type": "image/jpeg", "data": base64.b64encode(output_bytes.getvalue()).decode("utf-8")}


# --- Gemini Handler ---
class GeminiLiveHandler(AsyncAudioVideoStreamHandler):
    """Bridge a fastrtc audio/video stream to a Gemini Live session.

    Microphone audio is always forwarded to Gemini. Video frames are forwarded
    (throttled to one per ``FRAME_INTERVAL_S``) unless the selected mode is
    ``"none"``. Gemini's audio replies are queued and played back through
    :meth:`emit`, and incoming video frames are echoed back through
    :meth:`video_emit`.

    ``self.latest_args`` follows the inputs wired up in ``webrtc.stream``:
    ``[webrtc, instruction, mode]``.
    """

    def __init__(self) -> None:
        super().__init__(
            expected_layout="mono",
            output_sample_rate=24000,
            input_sample_rate=16000,
        )
        # int16 audio chunks from Gemini, shaped (1, n), waiting to be played.
        self.audio_queue: asyncio.Queue = asyncio.Queue()
        # Incoming video frames waiting to be echoed back to the client.
        self.video_queue: asyncio.Queue = asyncio.Queue()
        # Open Live session, or None when not connected.
        self.session = None
        # Set on shutdown to stop the receive loop in start_up.
        self.quit = asyncio.Event()
        # Monotonic time of the last frame sent to Gemini (-inf: none yet).
        self.last_frame_time = float("-inf")

    def copy(self) -> "GeminiLiveHandler":
        """Return a fresh, independent handler instance."""
        return GeminiLiveHandler()

    async def start_up(self) -> None:
        """Open the Gemini Live session and pump its audio replies into ``audio_queue``.

        Runs until :meth:`shutdown` sets ``self.quit`` or the session fails.

        Raises:
            WebRTCError: if ``GEMINI_API_KEY`` is unset, or if the session
                fails while the stream is still active.
        """
        await self.wait_for_args()
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise WebRTCError("API Key missing! Please set GEMINI_API_KEY in Secrets.")

        system_instruction = self.latest_args[1]
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1beta"})
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": system_instruction or "You are a helpful AI assistant.",
            "speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Zephyr"}}},
        }

        try:
            async with client.aio.live.connect(model=MODEL_NAME, config=config) as session:
                self.session = session
                # Bot speaks first to confirm connection.
                # NOTE(review): this text is sent as a *user* turn; the model's
                # reply to it is what the user hears first.
                await session.send(input="Hello! I am connected and ready. How can I help?", end_of_turn=True)
                # session.receive() finishes after each completed model turn, so
                # it must be re-entered for every turn. A single `async for`
                # would leave the `async with` (closing the session) right after
                # the greeting.
                while not self.quit.is_set():
                    async for response in session.receive():
                        if self.quit.is_set():
                            break
                        if data := response.data:
                            self.audio_queue.put_nowait(np.frombuffer(data, dtype=np.int16).reshape(1, -1))
        except Exception as e:
            # Closing the session from shutdown() can make the pending receive
            # fail; that is an expected end of stream, not an error to surface.
            if self.quit.is_set():
                return
            raise WebRTCError(f"Connection Error: {e}") from e
        finally:
            # Stop receive()/video_receive() from sending to a closed session.
            self.session = None

    async def video_receive(self, frame: np.ndarray) -> None:
        """Echo *frame* back to the client and, if enabled, forward it to Gemini.

        Frames are forwarded at most once every ``FRAME_INTERVAL_S`` seconds,
        and never when the selected mode is ``"none"``.
        """
        self.video_queue.put_nowait(frame)
        # Check the session before latest_args: start_up only sets the session
        # after wait_for_args() has returned, so latest_args is populated here.
        session = self.session
        if session is None or self.latest_args[2] == "none":
            return
        # NOTE(review): "camera" and "screen" are treated identically here; any
        # mode other than "none" forwards the frames this component delivers.
        now = time.monotonic()
        if now - self.last_frame_time < FRAME_INTERVAL_S:
            return
        self.last_frame_time = now
        await session.send(input=encode_image(frame))

    async def video_emit(self) -> np.ndarray:
        """Return the next echoed frame, or a black 640x480 frame if none is ready."""
        frame = await wait_for_item(self.video_queue, 0.01)
        return frame if frame is not None else np.zeros((480, 640, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one ``(sample_rate, samples)`` microphone frame to Gemini if connected."""
        session = self.session
        if session is not None:
            await session.send(input=encode_audio(frame[1].squeeze()))

    async def emit(self):
        """Return ``(output_sample_rate, chunk)`` for the next Gemini audio chunk, or None."""
        array = await wait_for_item(self.audio_queue, 0.01)
        return (self.output_sample_rate, array) if array is not None else None

    async def shutdown(self) -> None:
        """Stop the receive loop and close the Gemini session if one is open."""
        # Set unconditionally so a shutdown during connect still stops start_up.
        self.quit.set()
        session, self.session = self.session, None
        if session is not None:
            await session.close()


# --- Custom UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Gemini Live: Voice & Vision")
    with gr.Row():
        with gr.Column(scale=1):
            instruction = gr.Textbox(label="System Instruction", value="Be helpful and concise.")
            # 1. User selects mode
            mode = gr.Radio(choices=["camera", "screen", "none"], label="Select Video Mode")
            # 2. Start button is hidden until mode is selected
            start_btn = gr.Button("🚀 Start Conversation", variant="primary", visible=False)
        with gr.Column(scale=2):
            # 3. WebRTC component is hidden until Start is clicked
            webrtc = WebRTC(
                label="Gemini Live Stream",
                modality="audio-video",
                mode="send-receive",
                visible=False,
                rtc_configuration=get_cloudflare_turn_credentials(),
            )

    # Show the start button once a radio option is picked (hide it if cleared).
    mode.change(lambda x: gr.update(visible=bool(x)), [mode], [start_btn])

    def on_start():
        """Reveal the WebRTC audio/video component."""
        return gr.update(visible=True)

    start_btn.click(on_start, None, [webrtc])

    # Connect the WebRTC stream to the handler; the extra inputs become the
    # handler's latest_args as [webrtc, instruction, mode].
    webrtc.stream(
        fn=GeminiLiveHandler(),
        inputs=[webrtc, instruction, mode],
        outputs=[webrtc],
        time_limit=900,  # presumably seconds, i.e. 15 minutes per session
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)