# Audio_chat_bot / app.py
# IFMedTechdemo's picture
# Update app.py
# 0f20060 verified
import asyncio
import base64
import io
import os
import time
import numpy as np
import cv2
from PIL import Image
from google import genai
from fastrtc import AsyncAudioVideoStreamHandler, wait_for_item, WebRTCError, WebRTC, get_cloudflare_turn_credentials
import gradio as gr
# --- Encoder Helpers ---
def encode_audio(data: np.ndarray) -> dict:
    """Package audio samples as the payload dict sent to the live session.

    Args:
        data: Array of audio samples. Its in-memory bytes are forwarded
            unchanged, so the dtype decides the PCM sample format.

    Returns:
        A dict whose ``mime_type`` is ``"audio/pcm"`` and whose ``data`` is
        the base64 text of the array's raw bytes.
    """
    raw_bytes = data.tobytes()
    b64_text = str(base64.b64encode(raw_bytes), "UTF-8")
    return {"mime_type": "audio/pcm", "data": b64_text}
def encode_image(data: np.ndarray) -> dict:
    """Package a video frame as a base64 JPEG payload for the live session.

    Args:
        data: Image array. A three-dimensional array with exactly three
            channels is run through OpenCV's BGR-to-RGB conversion first;
            any other shape is handed to Pillow as-is.

    Returns:
        A dict whose ``mime_type`` is ``"image/jpeg"`` and whose ``data`` is
        the base64 text of the JPEG, downscaled in place to fit within
        1024x1024.
    """
    has_three_channels = data.ndim == 3 and data.shape[2] == 3
    if has_three_channels:
        data = cv2.cvtColor(data, cv2.COLOR_BGR2RGB)

    picture = Image.fromarray(data)
    picture.thumbnail([1024, 1024])

    # Read the bytes out before the buffer is closed by the context manager.
    with io.BytesIO() as jpeg_buffer:
        picture.save(jpeg_buffer, "JPEG")
        jpeg_bytes = jpeg_buffer.getvalue()

    b64_text = base64.b64encode(jpeg_bytes).decode("utf-8")
    return {"mime_type": "image/jpeg", "data": b64_text}
# --- Gemini Handler ---
class GeminiLiveHandler(AsyncAudioVideoStreamHandler):
    """Relay a WebRTC audio/video stream to and from a Gemini Live session.

    Incoming microphone audio and camera frames are forwarded to Gemini;
    audio returned by Gemini is queued and emitted back to the browser, and
    incoming video frames are echoed back unchanged.

    The handler is configured for mono audio, 16 kHz input and 24 kHz output.
    ``latest_args`` is expected to hold ``[webrtc_id, system_instruction,
    video_mode]``, matching the ``inputs`` wired up by the UI.
    """

    # Minimum number of seconds between two frames uploaded to Gemini.
    # Uploading every WebRTC frame would JPEG-encode and send each one.
    FRAME_SEND_INTERVAL = 1.0

    def __init__(self) -> None:
        super().__init__(expected_layout="mono", output_sample_rate=24000, input_sample_rate=16000)
        self.audio_queue = asyncio.Queue()  # audio chunks received from Gemini
        self.video_queue = asyncio.Queue()  # frames to echo back to the client
        self.session = None  # live session; None while disconnected
        self.quit = asyncio.Event()  # set once shutdown has been requested
        # -inf guarantees that the very first frame is always forwarded.
        self.last_frame_time = float("-inf")

    def copy(self) -> "GeminiLiveHandler":
        """Return a fresh handler with its own queues and session state."""
        return GeminiLiveHandler()

    async def start_up(self):
        """Open the Gemini Live session and pump its audio into ``audio_queue``.

        Runs until ``shutdown`` is called or the connection fails.

        Raises:
            WebRTCError: If ``GEMINI_API_KEY`` is unset, or if the connection
                fails for any reason other than a requested shutdown.
        """
        await self.wait_for_args()
        api_key = os.environ.get("GEMINI_API_KEY")
        if not api_key:
            raise WebRTCError("API Key missing! Please set GEMINI_API_KEY in Secrets.")

        system_instruction = self.latest_args[1]
        client = genai.Client(api_key=api_key, http_options={"api_version": "v1beta"})
        config = {
            "response_modalities": ["AUDIO"],
            "system_instruction": system_instruction or "You are a helpful AI assistant.",
            "speech_config": {"voice_config": {"prebuilt_voice_config": {"voice_name": "Zephyr"}}},
        }

        try:
            async with client.aio.live.connect(model="gemini-2.0-flash-exp", config=config) as session:
                self.session = session
                # Send an opening text turn so the model speaks first and the
                # user hears that the connection works.
                await session.send(input="Hello! I am connected and ready. How can I help?", end_of_turn=True)
                # The SDK's receive() iterator ends when a model turn
                # completes, so it must be re-entered for each turn. A single
                # pass would close the session right after the greeting.
                while not self.quit.is_set():
                    async for response in session.receive():
                        if self.quit.is_set():
                            break
                        if data := response.data:
                            self.audio_queue.put_nowait(np.frombuffer(data, dtype=np.int16).reshape(1, -1))
        except Exception as e:
            # Closing the socket from shutdown() makes the pending receive
            # fail; that is an expected exit, not a connection error.
            if self.quit.is_set():
                return
            raise WebRTCError(f"Connection Error: {str(e)}") from e
        finally:
            # Never leave a closed session visible to receive/video_receive.
            self.session = None

    async def video_receive(self, frame: np.ndarray):
        """Echo a client frame and, at a throttled rate, upload it to Gemini.

        Nothing is uploaded while disconnected or when the video mode is
        ``"none"``.
        """
        self.video_queue.put_nowait(frame)
        session = self.session
        if session is None or self.latest_args[2] == "none":
            return
        now = time.monotonic()
        if now - self.last_frame_time < self.FRAME_SEND_INTERVAL:
            return
        self.last_frame_time = now
        await session.send(input=encode_image(frame))

    async def video_emit(self) -> np.ndarray:
        """Return the next queued frame, or a black 480x640 frame if none arrives."""
        frame = await wait_for_item(self.video_queue, 0.01)
        return frame if frame is not None else np.zeros((480, 640, 3), dtype=np.uint8)

    async def receive(self, frame: tuple[int, np.ndarray]) -> None:
        """Forward one ``(sample_rate, samples)`` audio chunk to Gemini.

        Chunks that arrive while no session is open are dropped.
        """
        session = self.session
        if session is None:
            return
        _, samples = frame
        await session.send(input=encode_audio(samples.squeeze()))

    async def emit(self):
        """Return ``(output_sample_rate, samples)`` for the next chunk, or ``None``."""
        array = await wait_for_item(self.audio_queue, 0.01)
        return (self.output_sample_rate, array) if array is not None else None

    async def shutdown(self) -> None:
        """Signal the receive loop to stop and close the session if one is open."""
        # Set the flag even when no session exists yet, so a start_up that is
        # still connecting exits as soon as it gets its session.
        self.quit.set()
        session, self.session = self.session, None
        if session is not None:
            await session.close()
# --- Custom UI ---
# Build the Gradio page: the user picks a video mode, then presses Start to
# reveal the WebRTC component that streams to GeminiLiveHandler.
with gr.Blocks() as demo:
    gr.Markdown("# 🎙️ Gemini Live: Voice & Vision")
    with gr.Row():
        with gr.Column(scale=1):
            instruction = gr.Textbox(label="System Instruction", value="Be helpful and concise.")
            # 1. User selects mode
            mode = gr.Radio(choices=["camera", "screen", "none"], label="Select Video Mode")
            # 2. Start button is hidden until mode is selected
            start_btn = gr.Button("🚀 Start Conversation", variant="primary", visible=False)
        with gr.Column(scale=2):
            # 3. WebRTC component is hidden until Start is clicked
            webrtc = WebRTC(
                label="Gemini Live Stream",
                modality="audio-video",
                mode="send-receive",
                visible=False,
                # Pass the function itself rather than its result, so
                # credentials are requested per connection. Calling it here
                # would bake one set of time-limited TURN credentials into
                # the app at import time.
                rtc_configuration=get_cloudflare_turn_credentials,
            )

    # Show start button once a radio option is picked
    mode.change(lambda x: gr.update(visible=bool(x)), [mode], [start_btn])

    # When Start is clicked, show the video/audio interface
    def on_start():
        """Return the update that makes the WebRTC component visible."""
        return gr.update(visible=True)

    start_btn.click(on_start, None, [webrtc])

    # Connect the WebRTC stream to the handler; each stream is limited to
    # 900 seconds.
    webrtc.stream(
        fn=GeminiLiveHandler(),
        inputs=[webrtc, instruction, mode],
        outputs=[webrtc],
        time_limit=900,
    )
if __name__ == "__main__":
    # Script entry point: serve the app on all interfaces at port 7860.
    demo.launch(
        server_port=7860,
        server_name="0.0.0.0",
    )