File size: 5,050 Bytes
f40004d
 
 
 
 
 
 
 
 
0f20060
f40004d
 
0f20060
f40004d
0f20060
f40004d
 
 
 
 
 
 
 
0f20060
f40004d
0f20060
f40004d
 
0f20060
f40004d
 
 
 
 
 
 
 
 
 
0d25cdd
f40004d
0f20060
719147e
0d25cdd
0f20060
f40004d
 
 
0f20060
 
f40004d
719147e
0f20060
 
 
 
 
719147e
0f20060
719147e
0f20060
 
 
719147e
f8857ac
 
0f20060
f8857ac
 
 
 
719147e
f8857ac
 
 
0f20060
f8857ac
 
 
719147e
f8857ac
 
 
 
 
0f20060
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
 
f8857ac
 
0f20060
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
import asyncio
import base64
import io
import os
import time
import numpy as np
import cv2
from PIL import Image
from google import genai
from fastrtc import AsyncAudioVideoStreamHandler, wait_for_item, WebRTCError, WebRTC, get_cloudflare_turn_credentials
import gradio as gr

# --- Encoder Helpers ---
def encode_audio(data: np.ndarray) -> dict:
    """Pack a raw PCM audio buffer into the base64 payload dict sent to Gemini Live."""
    raw_bytes = data.tobytes()
    payload = base64.b64encode(raw_bytes).decode("UTF-8")
    return {"mime_type": "audio/pcm", "data": payload}

def encode_image(data: np.ndarray) -> dict:
    """Encode a video frame as a base64 JPEG payload for the Gemini Live API.

    Three-channel frames are assumed to arrive in OpenCV's BGR channel order
    (NOTE(review): confirm against the fastrtc frame source) and are converted
    to RGB before encoding. Four-channel (BGRA) frames previously crashed
    PIL's JPEG writer ("cannot write mode RGBA as JPEG"); the alpha channel is
    now dropped during the same conversion. Grayscale (2-D) frames pass
    through unchanged.

    Args:
        data: HxW, HxWx3 (BGR) or HxWx4 (BGRA) uint8 frame.

    Returns:
        {"mime_type": "image/jpeg", "data": <base64 str>} — the wire format
        consumed by session.send() in GeminiLiveHandler.
    """
    if data.ndim == 3:
        if data.shape[2] == 3:
            data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)
        elif data.shape[2] == 4:
            # JPEG cannot carry an alpha channel; drop it while fixing the order.
            data = cv2.cvtColor(data, cv2.COLOR_BGRA2RGB)
    with io.BytesIO() as output_bytes:
        pil_image = Image.fromarray(data)
        # Bound the payload size sent over the live session.
        pil_image.thumbnail((1024, 1024))
        pil_image.save(output_bytes, "JPEG")
        return {"mime_type": "image/jpeg", "data": base64.b64encode(output_bytes.getvalue()).decode("utf-8")}

# --- Gemini Handler ---
class GeminiLiveHandler(AsyncAudioVideoStreamHandler):
    """fastrtc handler bridging a browser WebRTC stream to a Gemini Live session.

    Audio flow: browser mic -> receive() -> Gemini; Gemini replies -> audio_queue
    -> emit() -> browser. Video flow: browser camera -> video_receive() -> Gemini
    (when a video mode is selected) and echoed back via video_emit().
    latest_args comes from the webrtc.stream(inputs=[webrtc, instruction, mode])
    wiring later in this file: [1] is the system-instruction text, [2] the video mode.
    """
    def __init__(self) -> None:
        # 16 kHz mono in (what Gemini Live accepts), 24 kHz out (what it produces).
        super().__init__(expected_layout="mono", output_sample_rate=24000, input_sample_rate=16000)
        # Audio chunks received from Gemini, awaiting emission to the browser.
        self.audio_queue = asyncio.Queue()
        # Frames received from the browser, awaiting echo via video_emit().
        self.video_queue = asyncio.Queue()
        # Live API session; None until start_up() connects.
        self.session = None
        # Set by shutdown() to stop the receive loop in start_up().
        self.quit = asyncio.Event()

    def copy(self) -> "GeminiLiveHandler":
        """Return a fresh handler instance (fastrtc creates one per connection)."""
        return GeminiLiveHandler()

    async def start_up(self):
        """Connect to Gemini Live and pump its audio responses into audio_queue.

        Runs for the lifetime of the connection; the `async with` keeps the
        session open while receive()/emit() use it concurrently.

        Raises:
            WebRTCError: if GEMINI_API_KEY is unset or the connection fails.
        """
        # Block until the UI inputs (instruction, mode) have arrived.
        await self.wait_for_args()
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise WebRTCError("API Key missing! Please set GEMINI_API_KEY in Secrets.")

        # latest_args[1] is the System Instruction textbox (see webrtc.stream inputs).
        system_instruction = self.latest_args[1]
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1beta"})
        
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": system_instruction or "You are a helpful AI assistant.",
            "speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Zephyr"}}}
        }

        try:
            async with client.aio.live.connect(model="gemini-2.0-flash-exp", config=config) as session:
                self.session = session
                # Bot speaks first to confirm connection
                await self.session.send(input="Hello! I am connected and ready. How can I help?", end_of_turn=True)
                async for response in self.session.receive():
                    if self.quit.is_set(): break
                    # response.data is raw 16-bit PCM; reshape to (1, n) mono frames.
                    if data := response.data:
                        self.audio_queue.put_nowait(np.frombuffer(data, dtype=np.int16).reshape(1, -1))
        except Exception as e:
            raise WebRTCError(f"Connection Error: {str(e)}")

    async def video_receive(self, frame: np.ndarray):
        """Queue the incoming frame for echo, and forward it to Gemini if a video mode is active."""
        self.video_queue.put_nowait(frame)
        # latest_args[2] is the video-mode radio ("camera"/"screen"/"none").
        # NOTE(review): assumes latest_args is populated before frames arrive — confirm.
        if self.latest_args[2] != "none" and self.session:
            await self.session.send(input=encode_image(frame))

    async def video_emit(self) -> np.ndarray:
        """Return the next queued frame, or a black 480x640 frame if none is ready in 10 ms."""
        frame = await wait_for_item(self.video_queue, 0.01)
        return frame if frame is not None else np.zeros((480, 640, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one (sample_rate, samples) mic frame to Gemini as PCM."""
        if self.session:
            await self.session.send(input=encode_audio(frame[1].squeeze()))

    async def emit(self):
        """Return the next (sample_rate, chunk) of Gemini audio, or None if nothing is ready in 10 ms."""
        array = await wait_for_item(self.audio_queue, 0.01)
        return (self.output_sample_rate, array) if array is not None else None

    async def shutdown(self) -> None:
        """Signal the receive loop to stop and close the live session."""
        if self.session:
            self.quit.set()
            await self.session.close()

# --- Custom UI ---
# --- Custom UI ---
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Gemini Live: Voice & Vision")

    with gr.Row():
        with gr.Column(scale=1):
            instruction = gr.Textbox(label="System Instruction", value="Be helpful and concise.")
            # 1. User selects mode (read by the handler as latest_args[2]).
            mode = gr.Radio(choices=["camera", "screen", "none"], label="Select Video Mode")
            # 2. Start button is hidden until mode is selected
            start_btn = gr.Button("🚀 Start Conversation", variant="primary", visible=False)

        with gr.Column(scale=2):
            # 3. WebRTC component is hidden until Start is clicked
            webrtc = WebRTC(
                label="Gemini Live Stream",
                modality="audio-video",
                mode="send-receive",
                visible=False,
                # Pass the *function*, not its result: Cloudflare TURN credentials
                # are short-lived, so fetching them once at app build time leaves
                # later sessions with expired credentials. fastrtc calls the
                # function per connection when given a callable.
                rtc_configuration=get_cloudflare_turn_credentials
            )

    # Show start button once a radio option is picked
    mode.change(lambda x: gr.update(visible=True) if x else gr.update(visible=False), [mode], [start_btn])

    # When Start is clicked, show the video/audio interface
    def on_start():
        """Reveal the hidden WebRTC component."""
        return gr.update(visible=True)

    start_btn.click(on_start, None, [webrtc])

    # Connect the WebRTC stream to the handler; these inputs become
    # latest_args inside it ([0]=webrtc, [1]=instruction, [2]=mode).
    webrtc.stream(
        fn=GeminiLiveHandler(),
        inputs=[webrtc, instruction, mode],
        outputs=[webrtc],
        time_limit=900
    )

if __name__ == "__main__":
    demo.launch(server_name="0.0.0.0", server_port=7860)